mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Compare commits
318 Commits
v1.3.9
...
release-br
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e53fa4274 | ||
|
|
d0e66b83d0 | ||
|
|
34fef77ff7 | ||
|
|
7006e967e4 | ||
|
|
f9d026334a | ||
|
|
36d4192c05 | ||
|
|
90eddb8268 | ||
|
|
c13a3c3c17 | ||
|
|
d42a78cba9 | ||
|
|
92f32ba16e | ||
|
|
4c838f6a5e | ||
|
|
adc0e1a51e | ||
|
|
3afdaef6e6 | ||
|
|
ef7432df55 | ||
|
|
fb807d7c37 | ||
|
|
f7c124d99d | ||
|
|
ed826f7a95 | ||
|
|
56f8113bb0 | ||
|
|
a04babbbc3 | ||
|
|
59b98245d3 | ||
|
|
f6be1ab1fb | ||
|
|
54f58cc2e5 | ||
|
|
eb8c5bc3fa | ||
|
|
3bcad10f92 | ||
|
|
d87a0a69c8 | ||
|
|
b7202f8469 | ||
|
|
a084c87370 | ||
|
|
5d87bd07ca | ||
|
|
3412d1f8b9 | ||
|
|
a72ff1e769 | ||
|
|
2c98b2c545 | ||
|
|
4792183c0d | ||
|
|
d88c860cac | ||
|
|
8b605da861 | ||
|
|
7cda5d7646 | ||
|
|
0cd873a88f | ||
|
|
1ff5d1f05a | ||
|
|
954395fa29 | ||
|
|
ea98a59aba | ||
|
|
6971d392b7 | ||
|
|
a2f8313668 | ||
|
|
af05cb2d94 | ||
|
|
5f0b9a24b9 | ||
|
|
37523fdc45 | ||
|
|
ca505f1140 | ||
|
|
a22f0579d5 | ||
|
|
9f656269ac | ||
|
|
13de41d854 | ||
|
|
42ea5f7fed | ||
|
|
af9386568f | ||
|
|
13b15e642d | ||
|
|
d4df2e7f72 | ||
|
|
5b8ed3a72f | ||
|
|
59fe94112d | ||
|
|
0ab51cdad7 | ||
|
|
0a1d6fa4db | ||
|
|
6e10bba7fe | ||
|
|
fc8268b70a | ||
|
|
b5f101f667 | ||
|
|
69b192c6fa | ||
|
|
ec85b1621d | ||
|
|
ddbb0f0db4 | ||
|
|
2996a161cd | ||
|
|
35e2a20019 | ||
|
|
65a300a807 | ||
|
|
a67aea88be | ||
|
|
84d4491a18 | ||
|
|
05d183c94b | ||
|
|
41282d0f51 | ||
|
|
f7fb555c89 | ||
|
|
2e63624f6c | ||
|
|
b2a54db4b5 | ||
|
|
0ef02bc15e | ||
|
|
b18cd7ee83 | ||
|
|
a16b25ad1d | ||
|
|
59ece456b1 | ||
|
|
d5cb327620 | ||
|
|
7a2277bc18 | ||
|
|
c736f4c1e9 | ||
|
|
f0cb810dd6 | ||
|
|
64632fa640 | ||
|
|
b9b9cfcade | ||
|
|
fc527dbdfb | ||
|
|
5641aab5bd | ||
|
|
31517ce750 | ||
|
|
51e58b64a5 | ||
|
|
3ca559e5a4 | ||
|
|
0e3f764299 | ||
|
|
e52402eb0c | ||
|
|
2133f31854 | ||
|
|
a198a5cd65 | ||
|
|
eb2b231bd2 | ||
|
|
7af29cfbc0 | ||
|
|
ce1a165348 | ||
|
|
fd48e6d795 | ||
|
|
d71d1341b6 | ||
|
|
21855df4af | ||
|
|
66e2d3a40a | ||
|
|
26257cf24a | ||
|
|
36a7423634 | ||
|
|
e616091249 | ||
|
|
0948161529 | ||
|
|
ce29b5d217 | ||
|
|
de24fa293e | ||
|
|
6663925c4d | ||
|
|
b9ece6d7b9 | ||
|
|
c4efa1ab97 | ||
|
|
7cea5305e1 | ||
|
|
a20fbf95de | ||
|
|
628c4302aa | ||
|
|
8dc34f8bf5 | ||
|
|
b4faf82f76 | ||
|
|
a983dfaee2 | ||
|
|
62f73bcaa2 | ||
|
|
00e9d2bdd3 | ||
|
|
ace3b1e66e | ||
|
|
d1ea1ba08c | ||
|
|
c06c8aa859 | ||
|
|
0c2cc00c4f | ||
|
|
8d6ea91f35 | ||
|
|
7dfb77228f | ||
|
|
24910f1fa6 | ||
|
|
433a61d2ee | ||
|
|
3937e885f0 | ||
|
|
c651003cc4 | ||
|
|
b7ccfcb8b4 | ||
|
|
a9ed70200b | ||
|
|
c6365e6b74 | ||
|
|
dacc67e50f | ||
|
|
c60cf33af3 | ||
|
|
f27cbe3525 | ||
|
|
2de1b9929a | ||
|
|
8bf654aece | ||
|
|
84376ed719 | ||
|
|
7a136b8874 | ||
|
|
58c0e4f15a | ||
|
|
e0d35d8ba2 | ||
|
|
3b2e48761e | ||
|
|
b27064008e | ||
|
|
1ad63827e1 | ||
|
|
20e61550c2 | ||
|
|
020b814402 | ||
|
|
e578867118 | ||
|
|
46a1039f21 | ||
|
|
cc9e27de5f | ||
|
|
6ab3ab9faf | ||
|
|
e68bfa795a | ||
|
|
e60a92e93e | ||
|
|
62fe14f76b | ||
|
|
a0c5062e3a | ||
|
|
49eb152d02 | ||
|
|
b05056423a | ||
|
|
c7168739c7 | ||
|
|
5b1faf1ce3 | ||
|
|
513a6f9ec7 | ||
|
|
8db6fa4232 | ||
|
|
5036de2602 | ||
|
|
332f8ccc37 | ||
|
|
a582195cec | ||
|
|
9fe36ae984 | ||
|
|
54cb455522 | ||
|
|
8bd3b9e474 | ||
|
|
eff5ff580b | ||
|
|
c45f863ed8 | ||
|
|
414d4e356d | ||
|
|
ef697eb781 | ||
|
|
0631ffe831 | ||
|
|
7444d8517a | ||
|
|
3480043e40 | ||
|
|
619b6e7516 | ||
|
|
0123ca44fb | ||
|
|
7929aafe2a | ||
|
|
dc433f8dc9 | ||
|
|
8ccaeeab60 | ||
|
|
043a28eb33 | ||
|
|
c329402f5d | ||
|
|
23e6ad6e1f | ||
|
|
e6de78c1fa | ||
|
|
a670708f93 | ||
|
|
4ebe2fb5f4 | ||
|
|
3403b2039d | ||
|
|
e30ad31e0f | ||
|
|
81e0bad739 | ||
|
|
7d07d738dc | ||
|
|
0fae584e65 | ||
|
|
9e83085f2a | ||
|
|
41a00c68ac | ||
|
|
e3b99bf339 | ||
|
|
5007a87d3a | ||
|
|
60e65a37a6 | ||
|
|
d37d0e942c | ||
|
|
98042d8dbd | ||
|
|
af4b826b68 | ||
|
|
253a57ca01 | ||
|
|
caf98b4dfe | ||
|
|
398f71fd00 | ||
|
|
e1301ade96 | ||
|
|
7a23f82192 | ||
|
|
715bcc4aa1 | ||
|
|
0c74838740 | ||
|
|
4b05b6da7b | ||
|
|
375844ff1a | ||
|
|
1d207379cb | ||
|
|
fb49cb71e3 | ||
|
|
9618efbcde | ||
|
|
bb2210b06a | ||
|
|
917052723d | ||
|
|
fef85cadeb | ||
|
|
4a05fb6b28 | ||
|
|
6644ce53f2 | ||
|
|
72f0b89fdc | ||
|
|
41a97a6609 | ||
|
|
38064d6ad5 | ||
|
|
ae6945cedf | ||
|
|
3132d1b032 | ||
|
|
2716ae29bd | ||
|
|
1c50c2b6af | ||
|
|
cf6d16b439 | ||
|
|
60686f55ff | ||
|
|
47d7ace3a7 | ||
|
|
2d3779ec27 | ||
|
|
595071b608 | ||
|
|
57ef717080 | ||
|
|
eb27d1482b | ||
|
|
f57972ead7 | ||
|
|
168eaf538b | ||
|
|
1560455ca3 | ||
|
|
028475a193 | ||
|
|
f7a6dbe39b | ||
|
|
e573a490c9 | ||
|
|
ce3281e70d | ||
|
|
0fbfd160c9 | ||
|
|
20759017e6 | ||
|
|
69e0aab73e | ||
|
|
7ed6733fb7 | ||
|
|
9718ab8579 | ||
|
|
2687a4a018 | ||
|
|
2d9c60dea1 | ||
|
|
841be069b7 | ||
|
|
7833132917 | ||
|
|
e9e63b0983 | ||
|
|
4df470b869 | ||
|
|
89600f6091 | ||
|
|
f986a575e8 | ||
|
|
9c2fe8d21f | ||
|
|
8bcbb9249e | ||
|
|
a95d50c0af | ||
|
|
5db7d3577b | ||
|
|
c53a0ca1c4 | ||
|
|
6fd3d1788a | ||
|
|
087c1975e5 | ||
|
|
3713cbecc3 | ||
|
|
6046789fa4 | ||
|
|
3ea69b180c | ||
|
|
db6e977e3a | ||
|
|
a5c776c846 | ||
|
|
5a566c028a | ||
|
|
ff43c74d8d | ||
|
|
3c7255569c | ||
|
|
4a92ec4d2d | ||
|
|
9bbccb4082 | ||
|
|
4f62314646 | ||
|
|
cb49d0d947 | ||
|
|
89f7874fc6 | ||
|
|
221917e80b | ||
|
|
37d41bd215 | ||
|
|
8a96b8bec4 | ||
|
|
02ee113b95 | ||
|
|
f71dd78915 | ||
|
|
cd5619a05b | ||
|
|
a63a30c76b | ||
|
|
f5ba8be182 | ||
|
|
a9f76322bd | ||
|
|
ed39269c80 | ||
|
|
09426dcd36 | ||
|
|
17941882a9 | ||
|
|
70ab8032a0 | ||
|
|
8360bdc50a | ||
|
|
6837176ec7 | ||
|
|
5e9b4244e7 | ||
|
|
9b6a308958 | ||
|
|
71e327653a | ||
|
|
a56711796f | ||
|
|
09495f2a7c | ||
|
|
484643e114 | ||
|
|
da91aabc35 | ||
|
|
c654398981 | ||
|
|
47a90ec2a1 | ||
|
|
2875e22d0b | ||
|
|
c5d14e0075 | ||
|
|
84e06c363c | ||
|
|
5b9ccc5065 | ||
|
|
6ca1a7ccc7 | ||
|
|
9d666be5d4 | ||
|
|
65de7edcde | ||
|
|
0cdff0d368 | ||
|
|
f87220a908 | ||
|
|
30ea0c6499 | ||
|
|
9501e35c60 | ||
|
|
5ac9d17bdf | ||
|
|
cb14992ddc | ||
|
|
e88372fc8c | ||
|
|
b320662d67 | ||
|
|
ce353cd4d9 | ||
|
|
4befd33866 | ||
|
|
4b36e3ac44 | ||
|
|
f507bc8f9e | ||
|
|
14c88f4a6d | ||
|
|
3e388c2857 | ||
|
|
cfe1209d61 | ||
|
|
5a88a7c22c | ||
|
|
8c661c4401 | ||
|
|
e6f256d640 | ||
|
|
ede354166b | ||
|
|
282a8ce78e | ||
|
|
08fe04f1ee | ||
|
|
082d14a9ba | ||
|
|
617674ce43 |
8
.github/workflows/ci.yml
vendored
8
.github/workflows/ci.yml
vendored
@@ -9,18 +9,18 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ["windows-latest", "ubuntu-latest", "macOS-latest"]
|
||||
go: ["1.21.x"]
|
||||
go: ["1.24.x"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- uses: WillAbides/setup-go-faster@v1.8.0
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- run: "go test -race ./..."
|
||||
- uses: dominikh/staticcheck-action@v1.2.0
|
||||
- uses: dominikh/staticcheck-action@v1.4.0
|
||||
with:
|
||||
version: "2023.1.2"
|
||||
version: "2025.1.1"
|
||||
install-go: false
|
||||
cache-key: ${{ matrix.go }}
|
||||
|
||||
193
README.md
193
README.md
@@ -4,14 +4,13 @@
|
||||
[](https://pkg.go.dev/github.com/Control-D-Inc/ctrld)
|
||||
[](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld)
|
||||
|
||||

|
||||

|
||||
|
||||
A highly configurable DNS forwarding proxy with support for:
|
||||
- Multiple listeners for incoming queries
|
||||
- Multiple upstreams with fallbacks
|
||||
- Multiple network policy driven DNS query steering
|
||||
- Multiple network policy driven DNS query steering (via network cidr, MAC address or FQDN)
|
||||
- Policy driven domain based "split horizon" DNS with wildcard support
|
||||
- Integrations with common router vendors and firmware
|
||||
- LAN client discovery via DHCP, mDNS, ARP, NDP, hosts file parsing
|
||||
- Prometheus metrics exporter
|
||||
|
||||
@@ -26,61 +25,58 @@ All DNS protocols are supported, including:
|
||||
- `DNS-over-QUIC`
|
||||
|
||||
# Use Cases
|
||||
1. Use secure DNS protocols on networks and devices that don't natively support them (legacy routers, legacy OSes, TVs, smart toasters).
|
||||
1. Use secure DNS protocols on networks and devices that don't natively support them (legacy OSes, TVs, smart toasters).
|
||||
2. Create source IP based DNS routing policies with variable secure DNS upstreams. Subnet 1 (admin) uses upstream resolver A, while Subnet 2 (employee) uses upstream resolver B.
|
||||
3. Create destination IP based DNS routing policies with variable secure DNS upstreams. Listener 1 uses upstream resolver C, while Listener 2 uses upstream resolver D.
|
||||
4. Create domain level "split horizon" DNS routing policies to send internal domains (*.company.int) to a local DNS server, while everything else goes to another upstream.
|
||||
5. Deploy on a router and create LAN client specific DNS routing policies from a web GUI (When using ControlD.com).
|
||||
|
||||
|
||||
## OS Support
|
||||
- Windows (386, amd64, arm)
|
||||
- Mac (amd64, arm64)
|
||||
- Windows Desktop (386, amd64, arm)
|
||||
- MacOS (amd64, arm64)
|
||||
- Linux (386, amd64, arm, mips)
|
||||
- FreeBSD
|
||||
- Common routers (See Router Mode below)
|
||||
- FreeBSD (386, amd64, arm)
|
||||
|
||||
# Install
|
||||
There are several ways to download and install `ctrld.
|
||||
There are several ways to download and install `ctrld`.
|
||||
|
||||
## Quick Install
|
||||
The simplest way to download and install `ctrld` is to use the following installer command on any UNIX-like platform:
|
||||
|
||||
```shell
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl)"'
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl?version=2)"'
|
||||
```
|
||||
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative cmd:
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative PowerShell:
|
||||
```shell
|
||||
powershell -Command "(Invoke-WebRequest -Uri 'https://api.controld.com/dl' -UseBasicParsing).Content | Set-Content 'ctrld_install.bat'" && ctrld_install.bat
|
||||
(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1?version=2' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'"
|
||||
```
|
||||
|
||||
Or you can pull and run a Docker container from [Docker Hub](https://hub.docker.com/r/controldns/ctrld)
|
||||
```
|
||||
$ docker pull controldns/ctrld
|
||||
```shell
|
||||
docker run -d --name=ctrld -p 127.0.0.1:53:53/tcp -p 127.0.0.1:53:53/udp controldns/ctrld:latest
|
||||
```
|
||||
|
||||
## Download Manually
|
||||
Alternatively, if you know what you're doing you can download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section for the appropriate platform.
|
||||
|
||||
## Build
|
||||
Lastly, you can build `ctrld` from source which requires `go1.21+`:
|
||||
Lastly, you can build `ctrld` from source which requires `go1.23+`:
|
||||
|
||||
```shell
|
||||
$ go build ./cmd/ctrld
|
||||
go build ./cmd/ctrld
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
$ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
$ docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
$ docker run -d --name=ctrld -p 53:53/tcp -p 53:53/udp controldns/ctrld --cd=RESOLVER_ID_GOES_HERE -vv
|
||||
```shell
|
||||
docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
```
|
||||
|
||||
|
||||
@@ -101,101 +97,118 @@ Usage:
|
||||
|
||||
Available Commands:
|
||||
run Run the DNS proxy server
|
||||
service Manage ctrld service
|
||||
start Quick start service and configure DNS on interface
|
||||
stop Quick stop service and remove DNS from interface
|
||||
restart Restart the ctrld service
|
||||
reload Reload the ctrld service
|
||||
status Show status of the ctrld service
|
||||
uninstall Stop and uninstall the ctrld service
|
||||
service Manage ctrld service
|
||||
clients Manage clients
|
||||
upgrade Upgrading ctrld to latest version
|
||||
log Manage runtime debug logs
|
||||
|
||||
Flags:
|
||||
-h, --help help for ctrld
|
||||
-s, --silent do not write any log output
|
||||
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug level logging
|
||||
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug logging
|
||||
--version version for ctrld
|
||||
|
||||
Use "ctrld [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
## Basic Run Mode
|
||||
To start the server with default configuration, simply run: `./ctrld run`. This will create a generic `ctrld.toml` file in the **working directory** and start the application in foreground.
|
||||
1. Start the server
|
||||
```
|
||||
$ sudo ./ctrld run
|
||||
This is the most basic way to run `ctrld`, in foreground mode. Unless you already have a config file, a default one will be generated.
|
||||
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe run
|
||||
```
|
||||
|
||||
2. Run a test query using a DNS client, for example, `dig`:
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld run
|
||||
```
|
||||
|
||||
You can then run a test query using a DNS client, for example, `dig`:
|
||||
```
|
||||
$ dig verify.controld.com @127.0.0.1 +short
|
||||
api.controld.com.
|
||||
147.185.34.1
|
||||
```
|
||||
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file and go nuts with it. To enforce a new config, restart the server.
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file that was generated. To enforce a new config, restart the server.
|
||||
|
||||
## Service Mode
|
||||
To run the application in service mode on any Windows, MacOS, Linux distibution or supported router, simply run: `./ctrld start` as system/root user. This will create a generic `ctrld.toml` file in the **user home** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and configure the listener on the default network interface. Service will start on OS boot.
|
||||
This mode will run the application as a background system service on any Windows, MacOS, Linux or FreeBSD distribution. This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot.
|
||||
|
||||
When Control D upstreams are used, `ctrld` willl [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
### Command
|
||||
|
||||
In order to stop the service, and restore your DNS to original state, simply run `./ctrld stop`. If you wish to stop and uninstall the service permanently, run `./ctrld uninstall`.
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start
|
||||
```
|
||||
|
||||
Linux or Macos
|
||||
```
|
||||
sudo ctrld start
|
||||
```
|
||||
|
||||
### Supported Routers
|
||||
You can run `ctrld` on any supported router, which will function similarly to the Service Mode mentioned above. The list of supported routers and firmware includes:
|
||||
- Asus Merlin
|
||||
- DD-WRT
|
||||
- Firewalla
|
||||
- FreshTomato
|
||||
- GL.iNet
|
||||
- OpenWRT
|
||||
- pfSense / OPNsense
|
||||
- Synology
|
||||
- Ubiquiti (UniFi, EdgeOS)
|
||||
If `ctrld` is not in your system path (you installed it manually), you will need to run the above commands from the directory where you installed `ctrld`.
|
||||
|
||||
`ctrld` will attempt to interface with dnsmasq whenever possible and set itself as the upstream, while running on port 5354. On FreeBSD based OSes, `ctrld` will terminate dnsmasq and unbound in order to be able to listen on port 53 directly.
|
||||
In order to stop the service, and restore your DNS to original state, simply run `ctrld stop`. If you wish to stop and uninstall the service permanently, run `ctrld uninstall`.
|
||||
|
||||
## Unmanaged Service Mode
|
||||
This mode functions similarly to the "Service Mode" above except it will simply start a system service and the config defined listeners, but **will not make any changes to any network interfaces**. You can then set the `ctrld` listener(s) IP on the desired network interfaces manually.
|
||||
|
||||
### Control D Auto Configuration
|
||||
Application can be started with a specific resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `run` (foreground) or `start` (service) modes.
|
||||
### Command
|
||||
|
||||
The following command will start the application in foreground mode, using the free "p2" resolver, which blocks Ads & Trackers.
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe service start
|
||||
```
|
||||
|
||||
```shell
|
||||
./ctrld run --cd p2
|
||||
```
|
||||
|
||||
Alternatively, you can use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Device.
|
||||
|
||||
```shell
|
||||
./ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above commands (in service mode only), the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- Your default network interface will be updated to use the listener started by the service
|
||||
- All OS DNS queries will be sent to the listener
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld service start
|
||||
```
|
||||
|
||||
# Configuration
|
||||
See [Configuration Docs](docs/config.md).
|
||||
`ctrld` can be configured in variety of different ways, which include: API, local config file or via cli launch args.
|
||||
|
||||
## Example
|
||||
- Start `listener.0` on 127.0.0.1:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `upstream.0` via DoH protocol
|
||||
## API Based Auto Configuration
|
||||
Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. This mode is used when the 1 liner installer command from the Control D onboarding guide is executed.
|
||||
|
||||
### Default Config
|
||||
The following command will use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Endpoint.
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --cd abcd1234
|
||||
```
|
||||
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above command, the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- All physical network interface will be updated to use the listener started by the service
|
||||
- All DNS queries will be sent to the listener
|
||||
|
||||
## Manual Configuration
|
||||
`ctrld` is entirely config driven and can be configured in many different ways, please see [Configuration Docs](docs/config.md).
|
||||
|
||||
### Example
|
||||
```toml
|
||||
[listener]
|
||||
|
||||
[listener.0]
|
||||
ip = ""
|
||||
port = 0
|
||||
restricted = false
|
||||
ip = '0.0.0.0'
|
||||
port = 53
|
||||
|
||||
[network]
|
||||
|
||||
@@ -203,10 +216,6 @@ See [Configuration Docs](docs/config.md).
|
||||
cidrs = ["0.0.0.0/0"]
|
||||
name = "Network 0"
|
||||
|
||||
[service]
|
||||
log_level = "info"
|
||||
log_path = ""
|
||||
|
||||
[upstream]
|
||||
|
||||
[upstream.0]
|
||||
@@ -215,28 +224,26 @@ See [Configuration Docs](docs/config.md).
|
||||
name = "Control D - Anti-Malware"
|
||||
timeout = 5000
|
||||
type = "doh"
|
||||
|
||||
[upstream.1]
|
||||
bootstrap_ip = "76.76.2.11"
|
||||
endpoint = "p2.freedns.controld.com"
|
||||
name = "Control D - No Ads"
|
||||
timeout = 3000
|
||||
type = "doq"
|
||||
|
||||
```
|
||||
|
||||
`ctrld` will pick a working config for `listener.0` then writing the default config to disk for the first run.
|
||||
The above basic config will:
|
||||
- Start listener on 0.0.0.0:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `https://freedns.controld.com/p1` using DoH protocol
|
||||
|
||||
## Advanced Configuration
|
||||
The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file.
|
||||
## CLI Args
|
||||
If you're unable to use a config file, `ctrld` can be be supplied with basic configuration via launch arguments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
|
||||
You can also supply configuration via launch argeuments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
### Example
|
||||
```
|
||||
ctrld run --listen=127.0.0.1:53 --primary_upstream=https://freedns.controld.com/p2 --secondary_upstream=10.0.10.1:53 --domains=*.company.int,very-secure.local --log /path/to/log.log
|
||||
```
|
||||
|
||||
The above will start a foreground process and:
|
||||
- Listen on `127.0.0.1:53` for DNS queries
|
||||
- Forward all queries to `https://freedns.controld.com/p2` using DoH protocol, while...
|
||||
- Excluding `*.company.int` and `very-secure.local` matching queries, that are forwarded to `10.0.10.1:53`
|
||||
- Write a debug log to `/path/to/log.log`
|
||||
|
||||
## Contributing
|
||||
See [Contribution Guideline](./docs/contributing.md)
|
||||
|
||||
## Roadmap
|
||||
The following functionality is on the roadmap and will be available in future releases.
|
||||
- DNS intercept mode
|
||||
- Direct listener mode
|
||||
- Support for more routers (let us know which ones)
|
||||
|
||||
10
cmd/cli/ad_others.go
Normal file
10
cmd/cli/ad_others.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule if present.
|
||||
func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false }
|
||||
73
cmd/cli/ad_windows.go
Normal file
73
cmd/cli/ad_windows.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
hh "github.com/microsoft/wmi/pkg/hardware/host"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule for domain if it's part of active directory.
|
||||
func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Msgf("Unable to get active directory domain: %v", err)
|
||||
return false
|
||||
}
|
||||
if domain == "" {
|
||||
mainLog.Load().Debug().Msg("No active directory domain found")
|
||||
return false
|
||||
}
|
||||
// Network rules are lowercase during toml config marshaling,
|
||||
// lowercase the domain here too for consistency.
|
||||
domain = strings.ToLower(domain)
|
||||
domainRuleAdded := addSplitDnsRule(cfg, domain)
|
||||
wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, "."))
|
||||
return domainRuleAdded || wildcardDomainRuleRuleAdded
|
||||
}
|
||||
|
||||
// addSplitDnsRule adds split-rule for given domain if there's no existed rule.
|
||||
// The return value indicates whether the split-rule was added or not.
|
||||
func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
|
||||
for n, lc := range cfg.Listener {
|
||||
if lc.Policy == nil {
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{}
|
||||
}
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
if _, ok := rule[domain]; ok {
|
||||
mainLog.Load().Debug().Msgf("Split-rule %q already existed for listener.%s", domain, n)
|
||||
return false
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Adding split-rule %q for listener.%s", domain, n)
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
cs, err := hh.GetComputerSystem(whost)
|
||||
if cs != nil {
|
||||
defer cs.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pod, err := cs.GetPropertyPartOfDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if pod {
|
||||
return cs.GetPropertyDomain()
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
71
cmd/cli/ad_windows_test.go
Normal file
71
cmd/cli/ad_windows_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_getActiveDirectoryDomain(t *testing.T) {
|
||||
start := time.Now()
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
domainPowershell, err := getActiveDirectoryDomainPowershell()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if domain != domainPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func getActiveDirectoryDomainPowershell() (string, error) {
|
||||
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
|
||||
output, err := powershell(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
func Test_addSplitDnsRule(t *testing.T) {
|
||||
newCfg := func(domains ...string) *ctrld.Config {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
lc := cfg.Listener["0"]
|
||||
for _, domain := range domains {
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *ctrld.Config
|
||||
domain string
|
||||
added bool
|
||||
}{
|
||||
{"added", newCfg(), "example.com", true},
|
||||
{"TLD existed", newCfg("example.com"), "*.example.com", true},
|
||||
{"wildcard existed", newCfg("*.example.com"), "example.com", true},
|
||||
{"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false},
|
||||
{"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
added := addSplitDnsRule(tc.cfg, tc.domain)
|
||||
assert.Equal(t, tc.added, added)
|
||||
})
|
||||
}
|
||||
}
|
||||
5
cmd/cli/cgo.go
Normal file
5
cmd/cli/cgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = true
|
||||
1784
cmd/cli/cli.go
1784
cmd/cli/cli.go
File diff suppressed because it is too large
Load Diff
141
cmd/cli/commands_clients.go
Normal file
141
cmd/cli/commands_clients.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
||||
)
|
||||
|
||||
// ClientsCommand handles clients-related operations
|
||||
type ClientsCommand struct {
|
||||
controlClient *controlClient
|
||||
}
|
||||
|
||||
// NewClientsCommand creates a new clients command handler
|
||||
func NewClientsCommand() (*ClientsCommand, error) {
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find ctrld home dir: %w", err)
|
||||
}
|
||||
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
return &ClientsCommand{
|
||||
controlClient: cc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListClients lists all connected clients
|
||||
func (cc *ClientsCommand) ListClients(cmd *cobra.Command, args []string) error {
|
||||
// Check service status first
|
||||
sc := NewServiceCommand()
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := cc.controlClient.post(listClientsPath, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get clients: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var clients []*clientinfo.Client
|
||||
if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil {
|
||||
return fmt.Errorf("failed to decode clients result: %w", err)
|
||||
}
|
||||
|
||||
map2Slice := func(m map[string]struct{}) []string {
|
||||
s := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
if k == "" { // skip empty source from output.
|
||||
continue
|
||||
}
|
||||
s = append(s, k)
|
||||
}
|
||||
sort.Strings(s)
|
||||
return s
|
||||
}
|
||||
|
||||
// If metrics is enabled, server set this for all clients, so we can check only the first one.
|
||||
// Ideally, we may have a field in response to indicate that query count should be shown, but
|
||||
// it would break earlier version of ctrld, which only look list of clients in response.
|
||||
withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount
|
||||
data := make([][]string, len(clients))
|
||||
for i, c := range clients {
|
||||
row := []string{
|
||||
c.IP.String(),
|
||||
c.Hostname,
|
||||
c.Mac,
|
||||
strings.Join(map2Slice(c.Source), ","),
|
||||
}
|
||||
if withQueryCount {
|
||||
row = append(row, strconv.FormatInt(c.QueryCount, 10))
|
||||
}
|
||||
data[i] = row
|
||||
}
|
||||
|
||||
table := tablewriter.NewWriter(os.Stdout)
|
||||
headers := []string{"IP", "Hostname", "Mac", "Discovered"}
|
||||
if withQueryCount {
|
||||
headers = append(headers, "Queries")
|
||||
}
|
||||
table.SetHeader(headers)
|
||||
table.SetAutoFormatHeaders(false)
|
||||
table.AppendBulk(data)
|
||||
table.Render()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitClientsCmd creates the clients command with proper logic
|
||||
func InitClientsCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
listClientsCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List clients that ctrld discovered",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cc, err := NewClientsCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cc.ListClients(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
clientsCmd := &cobra.Command{
|
||||
Use: "clients",
|
||||
Short: "Manage clients",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
listClientsCmd.Use,
|
||||
},
|
||||
}
|
||||
clientsCmd.AddCommand(listClientsCmd)
|
||||
rootCmd.AddCommand(clientsCmd)
|
||||
|
||||
return clientsCmd
|
||||
}
|
||||
87
cmd/cli/commands_interfaces.go
Normal file
87
cmd/cli/commands_interfaces.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// InterfacesCommand handles interfaces-related operations
|
||||
type InterfacesCommand struct{}
|
||||
|
||||
// NewInterfacesCommand creates a new interfaces command handler
|
||||
func NewInterfacesCommand() (*InterfacesCommand, error) {
|
||||
return &InterfacesCommand{}, nil
|
||||
}
|
||||
|
||||
// ListInterfaces lists all network interfaces
|
||||
func (ic *InterfacesCommand) ListInterfaces(cmd *cobra.Command, args []string) error {
|
||||
withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error {
|
||||
fmt.Printf("Index : %d\n", i.Index)
|
||||
fmt.Printf("Name : %s\n", i.Name)
|
||||
var status string
|
||||
if i.Flags&net.FlagUp != 0 {
|
||||
status = "Up"
|
||||
} else {
|
||||
status = "Down"
|
||||
}
|
||||
fmt.Printf("Status: %s\n", status)
|
||||
addrs, _ := i.Addrs()
|
||||
for i, ipaddr := range addrs {
|
||||
if i == 0 {
|
||||
fmt.Printf("Addrs : %v\n", ipaddr)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" %v\n", ipaddr)
|
||||
}
|
||||
nss, err := currentStaticDNS(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Failed to get DNS")
|
||||
}
|
||||
if len(nss) == 0 {
|
||||
nss = currentDNS(i)
|
||||
}
|
||||
for i, dns := range nss {
|
||||
if i == 0 {
|
||||
fmt.Printf("DNS : %s\n", dns)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" : %s\n", dns)
|
||||
}
|
||||
println()
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitInterfacesCmd creates the interfaces command with proper logic
|
||||
func InitInterfacesCmd(_ *cobra.Command) *cobra.Command {
|
||||
listInterfacesCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List network interfaces",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ic, err := NewInterfacesCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ic.ListInterfaces(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
interfacesCmd := &cobra.Command{
|
||||
Use: "interfaces",
|
||||
Short: "Manage network interfaces",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
listInterfacesCmd.Use,
|
||||
},
|
||||
}
|
||||
interfacesCmd.AddCommand(listInterfacesCmd)
|
||||
|
||||
return interfacesCmd
|
||||
}
|
||||
175
cmd/cli/commands_log.go
Normal file
175
cmd/cli/commands_log.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/docker/go-units"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// LogCommand handles log-related operations
|
||||
type LogCommand struct {
|
||||
controlClient *controlClient
|
||||
}
|
||||
|
||||
// NewLogCommand creates a new log command handler
|
||||
func NewLogCommand() (*LogCommand, error) {
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find ctrld home dir: %w", err)
|
||||
}
|
||||
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
return &LogCommand{
|
||||
controlClient: cc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// warnRuntimeLoggingNotEnabled logs a warning about runtime logging not being enabled
|
||||
func (lc *LogCommand) warnRuntimeLoggingNotEnabled() {
|
||||
mainLog.Load().Warn().Msg("Runtime debug logging is not enabled")
|
||||
mainLog.Load().Warn().Msg(`ctrld may be running without "--cd" flag or logging is already enabled`)
|
||||
}
|
||||
|
||||
// SendLogs sends runtime debug logs to ControlD
|
||||
func (lc *LogCommand) SendLogs(cmd *cobra.Command, args []string) error {
|
||||
sc := NewServiceCommand()
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := lc.controlClient.post(sendLogsPath, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send logs: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusServiceUnavailable:
|
||||
mainLog.Load().Warn().Msg("Runtime logs could only be sent once per minute")
|
||||
return nil
|
||||
case http.StatusMovedPermanently:
|
||||
lc.warnRuntimeLoggingNotEnabled()
|
||||
return nil
|
||||
}
|
||||
|
||||
var logs logSentResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil {
|
||||
return fmt.Errorf("failed to decode sent logs result: %w", err)
|
||||
}
|
||||
|
||||
if logs.Error != "" {
|
||||
return fmt.Errorf("failed to send logs: %s", logs.Error)
|
||||
}
|
||||
|
||||
mainLog.Load().Notice().Msgf("Sent %s of runtime logs", units.BytesSize(float64(logs.Size)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ViewLogs views current runtime debug logs
|
||||
func (lc *LogCommand) ViewLogs(cmd *cobra.Command, args []string) error {
|
||||
sc := NewServiceCommand()
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := lc.controlClient.post(viewLogsPath, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get logs: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusMovedPermanently:
|
||||
lc.warnRuntimeLoggingNotEnabled()
|
||||
return nil
|
||||
case http.StatusBadRequest:
|
||||
mainLog.Load().Warn().Msg("Runtime debug logs are not available")
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to read response body")
|
||||
}
|
||||
mainLog.Load().Warn().Msgf("ctrld process response:\n\n%s\n", string(buf))
|
||||
return nil
|
||||
case http.StatusOK:
|
||||
}
|
||||
|
||||
var logs logViewResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil {
|
||||
return fmt.Errorf("failed to decode view logs result: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(logs.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitLogCmd creates the log command with proper logic
|
||||
func InitLogCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
lc, err := NewLogCommand()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create log command: %v", err))
|
||||
}
|
||||
|
||||
logSendCmd := &cobra.Command{
|
||||
Use: "send",
|
||||
Short: "Send runtime debug logs to ControlD",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: lc.SendLogs,
|
||||
}
|
||||
|
||||
logViewCmd := &cobra.Command{
|
||||
Use: "view",
|
||||
Short: "View current runtime debug logs",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: lc.ViewLogs,
|
||||
}
|
||||
|
||||
logCmd := &cobra.Command{
|
||||
Use: "log",
|
||||
Short: "Manage runtime debug logs",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
logSendCmd.Use,
|
||||
logViewCmd.Use,
|
||||
},
|
||||
}
|
||||
logCmd.AddCommand(logSendCmd)
|
||||
logCmd.AddCommand(logViewCmd)
|
||||
rootCmd.AddCommand(logCmd)
|
||||
|
||||
return logCmd
|
||||
}
|
||||
59
cmd/cli/commands_run.go
Normal file
59
cmd/cli/commands_run.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// RunCommand handles run-related operations
|
||||
type RunCommand struct {
|
||||
// Add any dependencies here if needed in the future
|
||||
}
|
||||
|
||||
// NewRunCommand creates a new run command handler
|
||||
func NewRunCommand() *RunCommand {
|
||||
return &RunCommand{}
|
||||
}
|
||||
|
||||
// Run implements the logic for the run command
|
||||
func (rc *RunCommand) Run(cmd *cobra.Command, args []string) {
|
||||
RunCobraCommand(cmd)
|
||||
}
|
||||
|
||||
// InitRunCmd creates the run command with proper logic
|
||||
func InitRunCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
rc := NewRunCommand()
|
||||
|
||||
runCmd := &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "Run the DNS proxy server",
|
||||
Args: cobra.NoArgs,
|
||||
Run: rc.Run,
|
||||
}
|
||||
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
|
||||
runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
|
||||
runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
|
||||
runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
|
||||
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = runCmd.Flags().MarkHidden("dev")
|
||||
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
|
||||
_ = runCmd.Flags().MarkHidden("homedir")
|
||||
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
runCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
|
||||
|
||||
runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true}
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
return runCmd
|
||||
}
|
||||
256
cmd/cli/commands_service.go
Normal file
256
cmd/cli/commands_service.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// filterEmptyStrings removes empty strings from a slice
|
||||
// This is used to clean up command line arguments and configuration values
|
||||
func filterEmptyStrings(slice []string) []string {
|
||||
var result []string
|
||||
for _, s := range slice {
|
||||
if s != "" {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ServiceCommand handles service-related operations
|
||||
// This encapsulates all service management functionality for the CLI
|
||||
type ServiceCommand struct {
|
||||
serviceManager *ServiceManager
|
||||
}
|
||||
|
||||
// initializeServiceManager creates a service manager with default configuration
|
||||
// This sets up the basic service infrastructure needed for all service operations
|
||||
func (sc *ServiceCommand) initializeServiceManager() (service.Service, *prog, error) {
|
||||
svcConfig := sc.createServiceConfig()
|
||||
return sc.initializeServiceManagerWithServiceConfig(svcConfig)
|
||||
}
|
||||
|
||||
// initializeServiceManagerWithServiceConfig creates a service manager with the given configuration
|
||||
// This allows for custom service configuration while maintaining the same initialization pattern
|
||||
func (sc *ServiceCommand) initializeServiceManagerWithServiceConfig(svcConfig *service.Config) (service.Service, *prog, error) {
|
||||
p := &prog{}
|
||||
|
||||
s, err := sc.newService(p, svcConfig)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
sc.serviceManager = &ServiceManager{prog: p, svc: s}
|
||||
return s, p, nil
|
||||
}
|
||||
|
||||
// newService creates a new service instance using the provided program and configuration.
|
||||
// This abstracts the service creation process for different operating systems
|
||||
func (sc *ServiceCommand) newService(p *prog, svcConfig *service.Config) (service.Service, error) {
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewServiceCommand creates a new service command handler
|
||||
// This provides a clean factory method for creating service command instances
|
||||
func NewServiceCommand() *ServiceCommand {
|
||||
return &ServiceCommand{}
|
||||
}
|
||||
|
||||
// createServiceConfig creates a properly initialized service configuration
|
||||
// This ensures consistent service naming and description across all platforms
|
||||
func (sc *ServiceCommand) createServiceConfig() *service.Config {
|
||||
return &service.Config{
|
||||
Name: ctrldServiceName,
|
||||
DisplayName: "Control-D Helper Service",
|
||||
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
|
||||
Option: service.KeyValue{},
|
||||
}
|
||||
}
|
||||
|
||||
// InitServiceCmd creates the service command with proper logic and aliases
|
||||
// This sets up all service-related subcommands with appropriate permissions and flags
|
||||
func InitServiceCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
// Create service command handlers
|
||||
sc := NewServiceCommand()
|
||||
|
||||
startCmd, startCmdAlias := createStartCommands(sc)
|
||||
rootCmd.AddCommand(startCmdAlias)
|
||||
|
||||
// Stop command
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop",
|
||||
Short: "Stop the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Stop,
|
||||
}
|
||||
stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`)
|
||||
_ = stopCmd.Flags().MarkHidden("pin")
|
||||
|
||||
// Restart command
|
||||
restartCmd := &cobra.Command{
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Restart,
|
||||
}
|
||||
|
||||
// Status command
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: sc.Status,
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
// On darwin, running status command without privileges may return wrong information.
|
||||
statusCmd.PreRun = func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
}
|
||||
}
|
||||
|
||||
// Reload command
|
||||
reloadCmd := &cobra.Command{
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Reload,
|
||||
}
|
||||
|
||||
// Uninstall command
|
||||
uninstallCmd := &cobra.Command{
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Long: `Stop and uninstall the ctrld service.
|
||||
|
||||
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Uninstall,
|
||||
}
|
||||
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`)
|
||||
_ = uninstallCmd.Flags().MarkHidden("pin")
|
||||
uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`)
|
||||
|
||||
// Interfaces command - use the existing InitInterfacesCmd function
|
||||
interfacesCmd := InitInterfacesCmd(rootCmd)
|
||||
|
||||
stopCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "stop",
|
||||
Short: "Quick stop service and remove DNS from interface",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
return stopCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
|
||||
rootCmd.AddCommand(stopCmdAlias)
|
||||
|
||||
// Create aliases for other service commands
|
||||
restartCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return restartCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(restartCmdAlias)
|
||||
|
||||
reloadCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return reloadCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(reloadCmdAlias)
|
||||
|
||||
statusCmdAlias := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: statusCmd.RunE,
|
||||
}
|
||||
rootCmd.AddCommand(statusCmdAlias)
|
||||
|
||||
uninstallCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Long: `Stop and uninstall the ctrld service.
|
||||
|
||||
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
return uninstallCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags())
|
||||
rootCmd.AddCommand(uninstallCmdAlias)
|
||||
|
||||
// Create service command
|
||||
serviceCmd := &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "Manage ctrld service",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
}
|
||||
serviceCmd.ValidArgs = make([]string, 7)
|
||||
serviceCmd.ValidArgs[0] = startCmd.Use
|
||||
serviceCmd.ValidArgs[1] = stopCmd.Use
|
||||
serviceCmd.ValidArgs[2] = restartCmd.Use
|
||||
serviceCmd.ValidArgs[3] = reloadCmd.Use
|
||||
serviceCmd.ValidArgs[4] = statusCmd.Use
|
||||
serviceCmd.ValidArgs[5] = uninstallCmd.Use
|
||||
serviceCmd.ValidArgs[6] = interfacesCmd.Use
|
||||
|
||||
serviceCmd.AddCommand(startCmd)
|
||||
serviceCmd.AddCommand(stopCmd)
|
||||
serviceCmd.AddCommand(restartCmd)
|
||||
serviceCmd.AddCommand(reloadCmd)
|
||||
serviceCmd.AddCommand(statusCmd)
|
||||
serviceCmd.AddCommand(uninstallCmd)
|
||||
serviceCmd.AddCommand(interfacesCmd)
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
|
||||
return serviceCmd
|
||||
}
|
||||
41
cmd/cli/commands_service_manager.go
Normal file
41
cmd/cli/commands_service_manager.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// dialSocketControlServerTimeout is the default timeout to wait when ping control server.
|
||||
const dialSocketControlServerTimeout = 30 * time.Second
|
||||
|
||||
// ServiceManager handles service operations
|
||||
type ServiceManager struct {
|
||||
prog *prog
|
||||
svc service.Service
|
||||
}
|
||||
|
||||
// NewServiceManager creates a new service manager
|
||||
func NewServiceManager() (*ServiceManager, error) {
|
||||
p := &prog{}
|
||||
|
||||
// Create a proper service configuration
|
||||
svcConfig := &service.Config{
|
||||
Name: ctrldServiceName,
|
||||
DisplayName: "Control-D Helper Service",
|
||||
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
|
||||
Option: service.KeyValue{},
|
||||
}
|
||||
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
return &ServiceManager{prog: p, svc: s}, nil
|
||||
}
|
||||
|
||||
// Status returns the current service status
|
||||
func (sm *ServiceManager) Status() (service.Status, error) {
|
||||
return sm.svc.Status()
|
||||
}
|
||||
67
cmd/cli/commands_service_reload.go
Normal file
67
cmd/cli/commands_service_reload.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Reload implements the logic from cmdReload.Run
|
||||
func (sc *ServiceCommand) Reload(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service reload command started")
|
||||
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
logger.Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Failed to find ctrld home dir")
|
||||
}
|
||||
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
resp, err := cc.post(reloadPath, nil)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Failed to send reload signal to ctrld")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
logger.Notice().Msg("Service reloaded")
|
||||
case http.StatusCreated:
|
||||
logger.Warn().Msg("Service was reloaded, but new config requires service restart.")
|
||||
logger.Warn().Msg("Restarting service")
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
return sc.Restart(cmd, args)
|
||||
default:
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Could not read response from control server")
|
||||
}
|
||||
logger.Error().Err(err).Msgf("Failed to reload ctrld: %s", string(buf))
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service reload command completed")
|
||||
return nil
|
||||
}
|
||||
111
cmd/cli/commands_service_restart.go
Normal file
111
cmd/cli/commands_service_restart.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Restart implements the logic from cmdRestart.Run
|
||||
func (sc *ServiceCommand) Restart(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service restart command started")
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
cdUID = curCdUID()
|
||||
cdMode := cdUID != ""
|
||||
|
||||
s, p, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
|
||||
var validateConfigErr error
|
||||
if cdMode {
|
||||
logger.Debug().Msg("Validating ControlD remote config")
|
||||
validateConfigErr = doValidateCdRemoteConfig(cdUID, false)
|
||||
if validateConfigErr != nil {
|
||||
logger.Warn().Err(validateConfigErr).Msg("ControlD remote config validation failed")
|
||||
}
|
||||
}
|
||||
|
||||
if ir := runningIface(s); ir != nil {
|
||||
iface = ir.Name
|
||||
}
|
||||
|
||||
doRestart := func() bool {
|
||||
logger.Debug().Msg("Starting service restart sequence")
|
||||
|
||||
tasks := []task{
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
time.Sleep(time.Second * 1)
|
||||
return nil
|
||||
}, false, "Waiting for service to stop"},
|
||||
}
|
||||
if !doTasks(tasks) {
|
||||
logger.Error().Msg("Service stop tasks failed")
|
||||
return false
|
||||
}
|
||||
tasks = []task{
|
||||
{s.Start, true, "Start"},
|
||||
}
|
||||
success := doTasks(tasks)
|
||||
if success {
|
||||
logger.Debug().Msg("Service restart sequence completed successfully")
|
||||
} else {
|
||||
logger.Error().Msg("Service restart sequence failed")
|
||||
}
|
||||
return success
|
||||
}
|
||||
|
||||
if doRestart() {
|
||||
if dir, err := socketDir(); err == nil {
|
||||
timeout := dialSocketControlServerTimeout
|
||||
if validateConfigErr != nil {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
if cc := newSocketControlClientWithTimeout(context.TODO(), s, dir, timeout); cc != nil {
|
||||
_, _ = cc.post(ifacePath, nil)
|
||||
logger.Debug().Msg("Control server ping successful")
|
||||
} else {
|
||||
logger.Warn().Err(err).Msg("Service was restarted, but ctrld process may not be ready yet")
|
||||
}
|
||||
} else {
|
||||
logger.Warn().Err(err).Msg("Service was restarted, but could not ping the control server")
|
||||
}
|
||||
logger.Notice().Msg("Service restarted")
|
||||
} else {
|
||||
logger.Error().Msg("Service restart failed")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service restart command completed")
|
||||
return nil
|
||||
}
|
||||
387
cmd/cli/commands_service_start.go
Normal file
387
cmd/cli/commands_service_start.go
Normal file
@@ -0,0 +1,387 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Start implements the logic from cmdStart.Run
|
||||
func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service start command started")
|
||||
|
||||
checkStrFlagEmpty(cmd, cdUidFlagName)
|
||||
checkStrFlagEmpty(cmd, cdOrgFlagName)
|
||||
validateCdAndNextDNSFlags()
|
||||
|
||||
svcConfig := sc.createServiceConfig()
|
||||
osArgs := os.Args[2:]
|
||||
osArgs = filterEmptyStrings(osArgs)
|
||||
if os.Args[1] == "service" {
|
||||
osArgs = os.Args[3:]
|
||||
}
|
||||
setDependencies(svcConfig)
|
||||
svcConfig.Arguments = append([]string{"run"}, osArgs...)
|
||||
|
||||
// Initialize service manager with proper configuration
|
||||
s, p, err := sc.initializeServiceManagerWithServiceConfig(svcConfig)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
p.preRun()
|
||||
|
||||
status, err := s.Status()
|
||||
isCtrldRunning := status == service.StatusRunning
|
||||
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
|
||||
|
||||
// Get current running iface, if any.
|
||||
var currentIface *ifaceResponse
|
||||
|
||||
// If pin code was set, do not allow running start command.
|
||||
if isCtrldRunning {
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
logger.Error().Msg("Deactivation pin check failed")
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
currentIface = runningIface(s)
|
||||
logger.Debug().Msgf("Current interface on start: %v", currentIface)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
reportSetDnsOk := func(sockDir string) {
|
||||
if cc := newSocketControlClient(ctx, s, sockDir); cc != nil {
|
||||
if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK {
|
||||
if iface == "auto" {
|
||||
iface = defaultIfaceName()
|
||||
}
|
||||
res := &ifaceResponse{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(res); err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get iface info")
|
||||
return
|
||||
}
|
||||
if res.OK {
|
||||
name := res.Name
|
||||
if iff, err := net.InterfaceByName(name); err == nil {
|
||||
_, _ = patchNetIfaceName(iff)
|
||||
name = iff.Name
|
||||
}
|
||||
logger := logger.With().Str("iface", name)
|
||||
logger.Debug().Msg("Setting DNS successfully")
|
||||
if res.All {
|
||||
// Log that DNS is set for other interfaces.
|
||||
withEachPhysicalInterfaces(
|
||||
name,
|
||||
"set DNS",
|
||||
func(i *net.Interface) error { return nil },
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No config path, generating config in HOME directory.
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
|
||||
logServerStarted := make(chan struct{})
|
||||
stopLogCh := make(chan struct{})
|
||||
ud, err := userHomeDir()
|
||||
sockDir := ud
|
||||
var logServerSocketPath string
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get user home directory")
|
||||
logger.Warn().Msg("Log server did not start")
|
||||
close(logServerStarted)
|
||||
} else {
|
||||
setWorkingDirectory(svcConfig, ud)
|
||||
if configPath == "" && writeDefaultConfig {
|
||||
defaultConfigFile = filepath.Join(ud, defaultConfigFile)
|
||||
}
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--homedir="+ud)
|
||||
if d, err := socketDir(); err == nil {
|
||||
sockDir = d
|
||||
}
|
||||
logServerSocketPath = filepath.Join(sockDir, ctrldLogUnixSock)
|
||||
_ = os.Remove(logServerSocketPath)
|
||||
go func() {
|
||||
defer os.Remove(logServerSocketPath)
|
||||
|
||||
close(logServerStarted)
|
||||
|
||||
// Start HTTP log server
|
||||
if err := httpLogServer(logServerSocketPath, stopLogCh); err != nil && err != http.ErrServerClosed {
|
||||
logger.Warn().Err(err).Msg("Failed to serve HTTP log server")
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
<-logServerStarted
|
||||
|
||||
if !startOnly {
|
||||
startOnly = len(osArgs) == 0
|
||||
}
|
||||
// If user run "ctrld start" and ctrld is already installed, starting existing service.
|
||||
if startOnly && isCtrldInstalled {
|
||||
tryReadingConfigWithNotice(false, true)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
logger.Fatal().Msgf("Failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
// if already running, dont restart
|
||||
if isCtrldRunning {
|
||||
logger.Notice().Msg("Service is already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
tasks := []task{
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}, false, "Save current DNS"},
|
||||
{func() error {
|
||||
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||
}, false, "Configure service failure actions"},
|
||||
{s.Start, true, "Start"},
|
||||
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
|
||||
}
|
||||
logger.Notice().Msg("Starting existing ctrld service")
|
||||
if doTasks(tasks) {
|
||||
logger.Notice().Msg("Service started")
|
||||
sockDir, err := socketDir()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get socket directory")
|
||||
os.Exit(1)
|
||||
}
|
||||
reportSetDnsOk(sockDir)
|
||||
} else {
|
||||
logger.Error().Err(err).Msg("Failed to start existing ctrld service")
|
||||
os.Exit(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if cdUID != "" {
|
||||
_ = doValidateCdRemoteConfig(cdUID, true)
|
||||
} else if uid := cdUIDFromProvToken(); uid != "" {
|
||||
cdUID = uid
|
||||
logger.Debug().Msg("Using uid from provision token")
|
||||
removeOrgFlagsFromArgs(svcConfig)
|
||||
// Pass --cd flag to "ctrld run" command, so the provision token takes no effect.
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--cd="+cdUID)
|
||||
}
|
||||
if cdUID != "" {
|
||||
validateCdUpstreamProtocol()
|
||||
}
|
||||
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
|
||||
tryReadingConfigWithNotice(writeDefaultConfig, true)
|
||||
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
logger.Fatal().Msgf("Failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
|
||||
if nextdns != "" {
|
||||
removeNextDNSFromArgs(svcConfig)
|
||||
}
|
||||
|
||||
// Explicitly passing config, so on system where home directory could not be obtained,
|
||||
// or sub-process env is different with the parent, we still behave correctly and use
|
||||
// the expected config file.
|
||||
if configPath == "" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--config="+defaultConfigFile)
|
||||
}
|
||||
|
||||
tasks := []task{
|
||||
{s.Stop, false, "Stop"},
|
||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"},
|
||||
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
|
||||
//resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}, false, "Save current DNS"},
|
||||
{s.Install, false, "Install"},
|
||||
{func() error {
|
||||
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||
}, false, "Configure Windows service failure actions"},
|
||||
{s.Start, true, "Start"},
|
||||
// Note that startCmd do not actually write ControlD config, but the config file was
|
||||
// generated after s.Start, so we notice users here for consistent with nextdns mode.
|
||||
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
|
||||
}
|
||||
logger.Notice().Msg("Starting service")
|
||||
if doTasks(tasks) {
|
||||
// add a small delay to ensure the service is started and did not crash
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
ok, status, err := selfCheckStatus(ctx, s, sockDir)
|
||||
switch {
|
||||
case ok && status == service.StatusRunning:
|
||||
logger.Notice().Msg("Service started")
|
||||
default:
|
||||
marker := append(bytes.Repeat([]byte("="), 32), '\n')
|
||||
// If ctrld service is not running, emitting log obtained from ctrld process.
|
||||
if status != service.StatusRunning || ctx.Err() != nil {
|
||||
logger.Error().Msg("Ctrld service may not have started due to an error or misconfiguration, service log:")
|
||||
_, _ = logger.Write(marker)
|
||||
|
||||
// Wait for log collection to complete
|
||||
<-stopLogCh
|
||||
|
||||
// Retrieve logs from HTTP server if available
|
||||
if logServerSocketPath != "" {
|
||||
hlc := newHTTPLogClient(logServerSocketPath)
|
||||
logs, err := hlc.GetLogs()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get logs from HTTP log server")
|
||||
}
|
||||
if len(logs) == 0 {
|
||||
logger.Write([]byte("<no log output is obtained from ctrld process>\n"))
|
||||
} else {
|
||||
logger.Write(logs)
|
||||
logger.Write([]byte("\n"))
|
||||
}
|
||||
} else {
|
||||
logger.Write([]byte("<no log output from HTTP log server>\n"))
|
||||
}
|
||||
}
|
||||
// Report any error if occurred.
|
||||
if err != nil {
|
||||
_, _ = logger.Write(marker)
|
||||
msg := fmt.Sprintf("An error occurred while performing test query: %s\n", err)
|
||||
logger.Write([]byte(msg))
|
||||
}
|
||||
// If ctrld service is running but selfCheckStatus failed, it could be related
|
||||
// to user's system firewall configuration, notice users about it.
|
||||
if status == service.StatusRunning && err == nil {
|
||||
_, _ = logger.Write(marker)
|
||||
logger.Write([]byte("ctrld service was running, but a DNS query could not be sent to its listener\n"))
|
||||
logger.Write([]byte("Please check your system firewall if it is configured to block/intercept/redirect DNS queries\n"))
|
||||
}
|
||||
|
||||
_, _ = logger.Write(marker)
|
||||
uninstall(p, s)
|
||||
os.Exit(1)
|
||||
}
|
||||
reportSetDnsOk(sockDir)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service start command completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createStartCommands creates the start command and its alias
|
||||
func createStartCommands(sc *ServiceCommand) (*cobra.Command, *cobra.Command) {
|
||||
// Start command
|
||||
startCmd := &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "Install and start the ctrld service",
|
||||
Long: `Install and start the ctrld service
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
args = filterEmptyStrings(args)
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
|
||||
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Start,
|
||||
}
|
||||
// Keep these flags in sync with runCmd above, except for "-d"/"--nextdns".
|
||||
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
|
||||
startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
|
||||
startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
|
||||
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = startCmd.Flags().MarkHidden("dev")
|
||||
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id")
|
||||
startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`)
|
||||
startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service")
|
||||
_ = startCmd.Flags().MarkHidden("start_only")
|
||||
startCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
|
||||
|
||||
// Start command alias
|
||||
startCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Quick start service and configure DNS on interface",
|
||||
Long: `Quick start service and configure DNS on interface
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
args = filterEmptyStrings(args)
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
|
||||
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(os.Args) == 2 {
|
||||
startOnly = true
|
||||
}
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
return startCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
||||
|
||||
return startCmd, startCmdAlias
|
||||
}
|
||||
41
cmd/cli/commands_service_status.go
Normal file
41
cmd/cli/commands_service_status.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Status implements the logic from cmdStatus.Run
|
||||
func (sc *ServiceCommand) Status(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service status command started")
|
||||
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
logger.Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch status {
|
||||
case service.StatusUnknown:
|
||||
logger.Notice().Msg("Unknown status")
|
||||
os.Exit(2)
|
||||
case service.StatusRunning:
|
||||
logger.Notice().Msg("Service is running")
|
||||
os.Exit(0)
|
||||
case service.StatusStopped:
|
||||
logger.Notice().Msg("Service is stopped")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service status command completed")
|
||||
return nil
|
||||
}
|
||||
61
cmd/cli/commands_service_stop.go
Normal file
61
cmd/cli/commands_service_stop.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Stop implements the logic from cmdStop.Run
|
||||
func (sc *ServiceCommand) Stop(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service stop command started")
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
|
||||
s, p, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
logger.Warn().Msg("Service is already stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
logger.Error().Msg("Deactivation pin check failed")
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Stopping service")
|
||||
if doTasks([]task{{s.Stop, true, "Stop"}}) {
|
||||
logger.Notice().Msg("Service stopped")
|
||||
} else {
|
||||
logger.Error().Msg("Service stop failed")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service stop command completed")
|
||||
return nil
|
||||
}
|
||||
106
cmd/cli/commands_service_uninstall.go
Normal file
106
cmd/cli/commands_service_uninstall.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Uninstall implements the logic from cmdUninstall.Run
|
||||
func (sc *ServiceCommand) Uninstall(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service uninstall command started")
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
|
||||
s, p, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
logger.Error().Msg("Deactivation pin check failed")
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Starting service uninstall")
|
||||
uninstall(p, s)
|
||||
|
||||
if cleanup {
|
||||
logger.Debug().Msg("Performing cleanup operations")
|
||||
var files []string
|
||||
// Config file.
|
||||
files = append(files, v.ConfigFileUsed())
|
||||
// Log file and backup log file.
|
||||
// For safety, only process if log file path is absolute.
|
||||
if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) {
|
||||
files = append(files, logFile)
|
||||
oldLogFile := logFile + oldLogSuffix
|
||||
if _, err := os.Stat(oldLogFile); err == nil {
|
||||
files = append(files, oldLogFile)
|
||||
}
|
||||
}
|
||||
// Socket files.
|
||||
if dir, _ := socketDir(); dir != "" {
|
||||
files = append(files, filepath.Join(dir, ctrldControlUnixSock))
|
||||
files = append(files, filepath.Join(dir, ctrldLogUnixSock))
|
||||
}
|
||||
// Static DNS settings files.
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
file := ctrld.SavedStaticDnsSettingsFilePath(i)
|
||||
files = append(files, file)
|
||||
return nil
|
||||
})
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get executable path")
|
||||
}
|
||||
if bin != "" && supportedSelfDelete {
|
||||
files = append(files, bin)
|
||||
}
|
||||
// Backup file after upgrading.
|
||||
oldBin := bin + oldBinSuffix
|
||||
if _, err := os.Stat(oldBin); err == nil {
|
||||
files = append(files, oldBin)
|
||||
}
|
||||
for _, file := range files {
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
if err := os.Remove(file); err == nil {
|
||||
logger.Notice().Str("file", file).Msg("File removed during cleanup")
|
||||
} else {
|
||||
logger.Debug().Err(err).Str("file", file).Msg("Failed to remove file during cleanup")
|
||||
}
|
||||
}
|
||||
// Self-delete the ctrld binary if supported
|
||||
if err := selfDeleteExe(); err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to delete ctrld binary")
|
||||
} else {
|
||||
if !supportedSelfDelete {
|
||||
logger.Debug().Msgf("File removed: %s", bin)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Cleanup operations completed")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service uninstall command completed")
|
||||
return nil
|
||||
}
|
||||
197
cmd/cli/commands_test.go
Normal file
197
cmd/cli/commands_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestBasicCommandStructure tests the actual root command structure
|
||||
func TestBasicCommandStructure(t *testing.T) {
|
||||
// Test the actual root command that's returned from initCLI()
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command has basic properties
|
||||
assert.Equal(t, "ctrld", rootCmd.Use)
|
||||
assert.NotEmpty(t, rootCmd.Short, "Root command should have a short description")
|
||||
|
||||
// Test that root command has subcommands
|
||||
commands := rootCmd.Commands()
|
||||
assert.NotNil(t, commands, "Root command should have subcommands")
|
||||
assert.Greater(t, len(commands), 0, "Root command should have at least one subcommand")
|
||||
|
||||
// Test that expected commands exist
|
||||
expectedCommands := []string{"run", "service", "clients", "upgrade", "log"}
|
||||
for _, cmdName := range expectedCommands {
|
||||
found := false
|
||||
for _, cmd := range commands {
|
||||
if cmd.Name() == cmdName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected command %s not found in root command", cmdName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceCommandCreation tests service command creation
|
||||
func TestServiceCommandCreation(t *testing.T) {
|
||||
sc := NewServiceCommand()
|
||||
require.NotNil(t, sc, "ServiceCommand should be created")
|
||||
|
||||
// Test service config creation
|
||||
config := sc.createServiceConfig()
|
||||
require.NotNil(t, config, "Service config should be created")
|
||||
assert.Equal(t, ctrldServiceName, config.Name)
|
||||
assert.Equal(t, "Control-D Helper Service", config.DisplayName)
|
||||
assert.Equal(t, "A highly configurable, multi-protocol DNS forwarding proxy", config.Description)
|
||||
}
|
||||
|
||||
// TestServiceCommandSubCommands tests service command sub commands
|
||||
func TestServiceCommandSubCommands(t *testing.T) {
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "ctrld",
|
||||
Short: "DNS forwarding proxy",
|
||||
}
|
||||
|
||||
serviceCmd := InitServiceCmd(rootCmd)
|
||||
require.NotNil(t, serviceCmd, "Service command should be created")
|
||||
|
||||
// Test that service command has subcommands
|
||||
subcommands := serviceCmd.Commands()
|
||||
assert.Greater(t, len(subcommands), 0, "Service command should have subcommands")
|
||||
|
||||
// Test specific subcommands exist
|
||||
expectedCommands := []string{"start", "stop", "restart", "reload", "status", "uninstall", "interfaces"}
|
||||
|
||||
for _, cmdName := range expectedCommands {
|
||||
found := false
|
||||
for _, cmd := range subcommands {
|
||||
if cmd.Name() == cmdName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected service subcommand %s not found", cmdName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommandHelp tests basic help functionality
|
||||
func TestCommandHelp(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test help command execution
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
rootCmd.SetArgs([]string{"--help"})
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Help command should execute without error")
|
||||
assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description")
|
||||
}
|
||||
|
||||
// TestCommandVersion tests version command
|
||||
func TestCommandVersion(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
// Test version command
|
||||
rootCmd.SetArgs([]string{"--version"})
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Version command should execute without error")
|
||||
assert.Contains(t, buf.String(), "version", "Version output should contain version information")
|
||||
}
|
||||
|
||||
// TestCommandErrorHandling tests error handling
|
||||
func TestCommandErrorHandling(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test invalid flag instead of invalid command
|
||||
rootCmd.SetArgs([]string{"--invalid-flag"})
|
||||
err := rootCmd.Execute()
|
||||
assert.Error(t, err, "Invalid flag should return error")
|
||||
}
|
||||
|
||||
// TestCommandFlags tests flag functionality
|
||||
func TestCommandFlags(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command has expected flags
|
||||
verboseFlag := rootCmd.PersistentFlags().Lookup("verbose")
|
||||
assert.NotNil(t, verboseFlag, "Verbose flag should exist")
|
||||
assert.Equal(t, "v", verboseFlag.Shorthand)
|
||||
|
||||
silentFlag := rootCmd.PersistentFlags().Lookup("silent")
|
||||
assert.NotNil(t, silentFlag, "Silent flag should exist")
|
||||
assert.Equal(t, "s", silentFlag.Shorthand)
|
||||
}
|
||||
|
||||
// TestCommandExecution tests basic command execution
|
||||
func TestCommandExecution(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command can be executed (help command)
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
rootCmd.SetArgs([]string{"--help"})
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Root command should execute without error")
|
||||
assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description")
|
||||
}
|
||||
|
||||
// TestCommandArgs tests argument handling
|
||||
func TestCommandArgs(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command can handle arguments properly
|
||||
// Test with no args (should succeed)
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Root command with no args should execute")
|
||||
|
||||
// Test with help flag (should succeed)
|
||||
rootCmd.SetArgs([]string{"--help"})
|
||||
err = rootCmd.Execute()
|
||||
assert.NoError(t, err, "Root command with help flag should execute")
|
||||
}
|
||||
|
||||
// TestCommandSubcommands tests subcommand functionality
|
||||
func TestCommandSubcommands(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command has subcommands
|
||||
commands := rootCmd.Commands()
|
||||
assert.Greater(t, len(commands), 0, "Root command should have subcommands")
|
||||
|
||||
// Test that specific subcommands exist and can be executed
|
||||
expectedSubcommands := []string{"run", "service", "clients", "upgrade", "log"}
|
||||
for _, subCmdName := range expectedSubcommands {
|
||||
// Find the subcommand
|
||||
var subCmd *cobra.Command
|
||||
for _, cmd := range commands {
|
||||
if cmd.Name() == subCmdName {
|
||||
subCmd = cmd
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotNil(t, subCmd, "Subcommand %s should exist", subCmdName)
|
||||
|
||||
// Test that subcommand has help
|
||||
assert.NotEmpty(t, subCmd.Short, "Subcommand %s should have a short description", subCmdName)
|
||||
}
|
||||
}
|
||||
192
cmd/cli/commands_upgrade.go
Normal file
192
cmd/cli/commands_upgrade.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/minio/selfupdate"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
upgradeChannelDev = "dev"
|
||||
upgradeChannelProd = "prod"
|
||||
upgradeChannelDefault = "default"
|
||||
)
|
||||
|
||||
// UpgradeCommand handles upgrade-related operations
|
||||
type UpgradeCommand struct {
|
||||
}
|
||||
|
||||
// NewUpgradeCommand creates a new upgrade command handler
|
||||
func NewUpgradeCommand() (*UpgradeCommand, error) {
|
||||
return &UpgradeCommand{}, nil
|
||||
}
|
||||
|
||||
// Upgrade performs the upgrade operation
|
||||
func (uc *UpgradeCommand) Upgrade(cmd *cobra.Command, args []string) error {
|
||||
upgradeChannel := map[string]string{
|
||||
upgradeChannelDefault: "https://dl.controld.dev",
|
||||
upgradeChannelDev: "https://dl.controld.dev",
|
||||
upgradeChannelProd: "https://dl.controld.com",
|
||||
}
|
||||
if isStableVersion(curVersion()) {
|
||||
upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd]
|
||||
}
|
||||
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to get current ctrld binary path")
|
||||
}
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
svcCmd := NewServiceCommand()
|
||||
s, p, err := svcCmd.initializeServiceManager()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
svcInstalled := true
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
svcInstalled = false
|
||||
}
|
||||
|
||||
oldBin := bin + oldBinSuffix
|
||||
baseUrl := upgradeChannel[upgradeChannelDefault]
|
||||
if len(args) > 0 {
|
||||
channel := args[0]
|
||||
switch channel {
|
||||
case upgradeChannelProd, upgradeChannelDev: // ok
|
||||
default:
|
||||
mainLog.Load().Fatal().Msgf("Upgrade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev)
|
||||
}
|
||||
baseUrl = upgradeChannel[channel]
|
||||
}
|
||||
|
||||
dlUrl := upgradeUrl(baseUrl)
|
||||
mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl)
|
||||
|
||||
resp, err := getWithRetry(dlUrl, downloadServerIp)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to download binary")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
mainLog.Load().Fatal().Msgf("Could not download binary: %s", http.StatusText(resp.StatusCode))
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msg("Updating current binary")
|
||||
if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil {
|
||||
if rerr := selfupdate.RollbackError(err); rerr != nil {
|
||||
mainLog.Load().Error().Err(rerr).Msg("Could not rollback old binary")
|
||||
}
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to update current binary")
|
||||
}
|
||||
|
||||
doRestart := func() bool {
|
||||
if !svcInstalled {
|
||||
return true
|
||||
}
|
||||
tasks := []task{
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
time.Sleep(time.Second * 1)
|
||||
return nil
|
||||
}, false, "Waiting for service to stop"},
|
||||
}
|
||||
doTasks(tasks)
|
||||
|
||||
tasks = []task{
|
||||
{s.Start, true, "Start"},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
if dir, err := socketDir(); err == nil {
|
||||
if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil {
|
||||
_, _ = cc.post(ifacePath, nil)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if svcInstalled {
|
||||
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
|
||||
}
|
||||
|
||||
if doRestart() {
|
||||
_ = os.Remove(oldBin)
|
||||
_ = os.Chmod(bin, 0755)
|
||||
ver := "unknown version"
|
||||
out, err := exec.Command(bin, "--version").CombinedOutput()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version")
|
||||
}
|
||||
if after, found := strings.CutPrefix(string(out), "ctrld version "); found {
|
||||
ver = after
|
||||
}
|
||||
mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver)
|
||||
return nil
|
||||
}
|
||||
|
||||
mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin)
|
||||
if err := os.Remove(bin); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to remove new binary")
|
||||
}
|
||||
if err := os.Rename(oldBin, bin); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to restore old binary")
|
||||
}
|
||||
if doRestart() {
|
||||
mainLog.Load().Notice().Msg("Restored previous binary successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitUpgradeCmd creates the upgrade command with proper logic
|
||||
func InitUpgradeCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
upgradeCmd := &cobra.Command{
|
||||
Use: "upgrade",
|
||||
Short: "Upgrading ctrld to latest version",
|
||||
ValidArgs: []string{upgradeChannelDev, upgradeChannelProd},
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
uc, err := NewUpgradeCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return uc.Upgrade(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(upgradeCmd)
|
||||
|
||||
return upgradeCmd
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// logConn wraps a net.Conn, override the Write behavior.
|
||||
// runCmd uses this wrapper, so as long as startCmd finished,
|
||||
// ctrld log won't be flushed with un-necessary write errors.
|
||||
type logConn struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (lc *logConn) Read(b []byte) (n int, err error) {
|
||||
return lc.conn.Read(b)
|
||||
}
|
||||
|
||||
func (lc *logConn) Close() error {
|
||||
return lc.conn.Close()
|
||||
}
|
||||
|
||||
func (lc *logConn) LocalAddr() net.Addr {
|
||||
return lc.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (lc *logConn) RemoteAddr() net.Addr {
|
||||
return lc.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (lc *logConn) SetDeadline(t time.Time) error {
|
||||
return lc.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (lc *logConn) SetReadDeadline(t time.Time) error {
|
||||
return lc.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (lc *logConn) SetWriteDeadline(t time.Time) error {
|
||||
return lc.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (lc *logConn) Write(b []byte) (int, error) {
|
||||
// Write performs writes with underlying net.Conn, ignore any errors happen.
|
||||
// "ctrld run" command use this wrapper to report errors to "ctrld start".
|
||||
// If no error occurred, "ctrld start" may finish before "ctrld run" attempt
|
||||
// to close the connection, so ignore errors conservatively here, prevent
|
||||
// un-necessary error "write to closed connection" flushed to ctrld log.
|
||||
_, _ = lc.conn.Write(b)
|
||||
return len(b), nil
|
||||
}
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// controlClient represents an HTTP client for communicating with the control server
|
||||
type controlClient struct {
|
||||
c *http.Client
|
||||
}
|
||||
|
||||
// newControlClient creates a new control client with Unix socket transport
|
||||
func newControlClient(addr string) *controlClient {
|
||||
return &controlClient{c: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
@@ -25,6 +27,10 @@ func newControlClient(addr string) *controlClient {
|
||||
}
|
||||
|
||||
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
|
||||
// for log/send, set the timeout to 5 minutes
|
||||
if path == sendLogsPath {
|
||||
c.c.Timeout = time.Minute * 5
|
||||
}
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -11,10 +13,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,14 +27,24 @@ const (
|
||||
deactivationPath = "/deactivation"
|
||||
cdPath = "/cd"
|
||||
ifacePath = "/iface"
|
||||
viewLogsPath = "/log/view"
|
||||
sendLogsPath = "/log/send"
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
Name string `json:"name"`
|
||||
All bool `json:"all"`
|
||||
OK bool `json:"ok"`
|
||||
}
|
||||
|
||||
// controlServer represents an HTTP server for handling control requests
|
||||
type controlServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
addr string
|
||||
}
|
||||
|
||||
// newControlServer creates a new control server instance
|
||||
func newControlServer(addr string) (*controlServer, error) {
|
||||
mux := http.NewServeMux()
|
||||
s := &controlServer{
|
||||
@@ -69,33 +81,81 @@ func (s *controlServer) register(pattern string, handler http.Handler) {
|
||||
|
||||
func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
p.Debug().Msg("Handling list clients request")
|
||||
|
||||
clients := p.ciTable.ListClients()
|
||||
p.Debug().Int("client_count", len(clients)).Msg("Retrieved clients list")
|
||||
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].IP.Less(clients[j].IP)
|
||||
})
|
||||
p.Debug().Msg("Sorted clients by IP address")
|
||||
|
||||
if p.metricsQueryStats.Load() {
|
||||
for _, client := range clients {
|
||||
p.Debug().Msg("Metrics query stats enabled, collecting query counts")
|
||||
|
||||
for idx, client := range clients {
|
||||
p.Debug().
|
||||
Int("index", idx).
|
||||
Str("ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("Processing client metrics")
|
||||
|
||||
client.IncludeQueryCount = true
|
||||
dm := &dto.Metric{}
|
||||
|
||||
if statsClientQueriesCount.MetricVec == nil {
|
||||
p.Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("Skipping metrics collection: MetricVec is nil")
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
|
||||
client.IP.String(),
|
||||
client.Mac,
|
||||
client.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("could not get metrics for client: %v", client)
|
||||
p.Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("Failed to get metrics for client")
|
||||
continue
|
||||
}
|
||||
if err := m.Write(dm); err == nil {
|
||||
|
||||
if err := m.Write(dm); err == nil && dm.Counter != nil {
|
||||
client.QueryCount = int64(dm.Counter.GetValue())
|
||||
p.Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Int64("query_count", client.QueryCount).
|
||||
Msg("Successfully collected query count")
|
||||
} else if err != nil {
|
||||
p.Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("Failed to write metric")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
p.Debug().Msg("Metrics query stats disabled, skipping query counts")
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&clients); err != nil {
|
||||
p.Error().
|
||||
Err(err).
|
||||
Int("client_count", len(clients)).
|
||||
Msg("Failed to encode clients response")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
p.Debug().
|
||||
Int("client_count", len(clients)).
|
||||
Msg("Successfully sent clients list response")
|
||||
}))
|
||||
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
select {
|
||||
@@ -117,14 +177,14 @@ func (p *prog) registerControlServerHandler() {
|
||||
oldSvc := p.cfg.Service
|
||||
p.mu.Unlock()
|
||||
if err := p.sendReloadSignal(); err != nil {
|
||||
mainLog.Load().Err(err).Msg("could not send reload signal")
|
||||
p.Error().Err(err).Msg("Could not send reload signal")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.reloadDoneCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
http.Error(w, "Timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -152,8 +212,26 @@ func (p *prog) registerControlServerHandler() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
p.cs.register(deactivationPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
// Non-cd mode or pin code not set, always allowing deactivation.
|
||||
if cdUID == "" || deactivationPinNotSet() {
|
||||
// Non-cd mode always allowing deactivation.
|
||||
if cdUID == "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
// Re-fetch pin code from API.
|
||||
if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev); rc != nil {
|
||||
if rc.DeactivationPin != nil {
|
||||
cdDeactivationPin.Store(*rc.DeactivationPin)
|
||||
} else {
|
||||
cdDeactivationPin.Store(defaultDeactivationPin)
|
||||
}
|
||||
} else {
|
||||
p.Warn().Err(err).Msg("Could not re-fetch deactivation pin code")
|
||||
}
|
||||
|
||||
// If pin code not set, allowing deactivation.
|
||||
if !deactivationPinSet() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
@@ -161,14 +239,18 @@ func (p *prog) registerControlServerHandler() {
|
||||
var req deactivationRequest
|
||||
if err := json.NewDecoder(request.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusPreconditionFailed)
|
||||
mainLog.Load().Err(err).Msg("invalid deactivation request")
|
||||
p.Error().Err(err).Msg("Invalid deactivation request")
|
||||
return
|
||||
}
|
||||
|
||||
code := http.StatusForbidden
|
||||
switch req.Pin {
|
||||
case cdDeactivationPin:
|
||||
case cdDeactivationPin.Load():
|
||||
code = http.StatusOK
|
||||
select {
|
||||
case p.pinCodeValidCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
case defaultDeactivationPin:
|
||||
// If the pin code was set, but users do not provide --pin, return proper code to client.
|
||||
code = http.StatusBadRequest
|
||||
@@ -184,18 +266,81 @@ func (p *prog) registerControlServerHandler() {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
res := &ifaceResponse{Name: iface}
|
||||
// p.setDNS is only called when running as a service
|
||||
if !service.Interactive() {
|
||||
<-p.csSetDnsDone
|
||||
if p.csSetDnsOk {
|
||||
w.Write([]byte(iface))
|
||||
return
|
||||
res.Name = p.runningIface
|
||||
res.All = p.requiredMultiNICsConfig
|
||||
res.OK = true
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
if err := json.NewEncoder(w).Encode(res); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
lr, err := p.logReaderRaw()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer lr.r.Close()
|
||||
if lr.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(lr.r)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if time.Since(p.internalLogSent) < logWriterSentInterval {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r, err := p.logReaderNoColor()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
req := &controld.LogsRequest{
|
||||
UID: cdUID,
|
||||
Data: r.r,
|
||||
}
|
||||
p.Debug().Msg("Sending log file to ControlD server")
|
||||
resp := logSentResponse{Size: r.size}
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil {
|
||||
p.Error().Msgf("Could not send log file to ControlD server: %v", err)
|
||||
resp.Error = err.Error()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
p.Debug().Msg("Sending log file successfully")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
p.internalLogSent = time.Now()
|
||||
}))
|
||||
}
|
||||
|
||||
// jsonResponse wraps an HTTP handler to set JSON content type
|
||||
func jsonResponse(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
package cli
|
||||
|
||||
//lint:ignore U1000 use in os_linux.go
|
||||
type getDNS func(iface string) []string
|
||||
1679
cmd/cli/dns_proxy.go
1679
cmd/cli/dns_proxy.go
File diff suppressed because it is too large
Load Diff
@@ -30,6 +30,7 @@ func Test_wildcardMatches(t *testing.T) {
|
||||
{"domain - suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
||||
{"domain - both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
||||
{"domain - both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
||||
{"domain - case-insensitive", "*.WINDSCRIBE.com", "anything.windscribe.com", true},
|
||||
{"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false},
|
||||
{"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false},
|
||||
@@ -74,8 +75,10 @@ func Test_canonicalName(t *testing.T) {
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
||||
p := &prog{cfg: cfg}
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.logger.Store(mainLog.Load())
|
||||
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
for _, nc := range p.cfg.Network {
|
||||
@@ -140,9 +143,94 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_upstreamForWithCustomMatching(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
prog.logger.Store(mainLog.Load())
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a custom policy with domain-first matching order
|
||||
customPolicy := &ctrld.ListenerPolicyConfig{
|
||||
Name: "Custom Policy",
|
||||
Networks: []ctrld.Rule{
|
||||
{"network.0": []string{"upstream.1", "upstream.0"}},
|
||||
},
|
||||
Macs: []ctrld.Rule{
|
||||
{"14:45:A0:67:83:0A": []string{"upstream.2"}},
|
||||
},
|
||||
Rules: []ctrld.Rule{
|
||||
{"*.ru": []string{"upstream.1"}},
|
||||
},
|
||||
Matching: &ctrld.MatchingConfig{
|
||||
Order: []string{"domain", "mac", "network"},
|
||||
},
|
||||
}
|
||||
|
||||
customListener := &ctrld.ListenerConfig{
|
||||
Policy: customPolicy,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
mac string
|
||||
domain string
|
||||
upstreams []string
|
||||
matched bool
|
||||
}{
|
||||
{
|
||||
name: "Domain rule should match first with custom order",
|
||||
ip: "192.168.0.1:0",
|
||||
mac: "14:45:A0:67:83:0A",
|
||||
domain: "example.ru",
|
||||
upstreams: []string{"upstream.1"},
|
||||
matched: true,
|
||||
},
|
||||
{
|
||||
name: "MAC rule should match when no domain rule",
|
||||
ip: "192.168.0.1:0",
|
||||
mac: "14:45:A0:67:83:0A",
|
||||
domain: "example.com",
|
||||
upstreams: []string{"upstream.2"},
|
||||
matched: true,
|
||||
},
|
||||
{
|
||||
name: "Network rule should match when no domain or MAC rule",
|
||||
ip: "192.168.0.1:0",
|
||||
mac: "00:11:22:33:44:55",
|
||||
domain: "example.com",
|
||||
upstreams: []string{"upstream.1", "upstream.0"},
|
||||
matched: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
addr, err := net.ResolveUDPAddr("udp", tc.ip)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
ufr := prog.upstreamFor(ctx, "0", customListener, addr, tc.mac, tc.domain)
|
||||
|
||||
assert.Equal(t, tc.matched, ufr.matched)
|
||||
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
prog.logger.Store(mainLog.Load())
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
@@ -364,6 +452,9 @@ func Test_isLanHostnameQuery(t *testing.T) {
|
||||
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
||||
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
||||
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
||||
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
||||
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
||||
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
@@ -413,6 +504,27 @@ func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSrvLanLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isSrvLookup bool
|
||||
}{
|
||||
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
||||
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
||||
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isWanClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -438,3 +550,254 @@ func Test_isWanClient(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_shouldStartRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
hasExistingRecovery bool
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "network change with existing recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
hasExistingRecovery: true,
|
||||
expectedResult: true,
|
||||
description: "should cancel existing recovery and start new one for network change",
|
||||
},
|
||||
{
|
||||
name: "network change without existing recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
hasExistingRecovery: false,
|
||||
expectedResult: true,
|
||||
description: "should start new recovery for network change",
|
||||
},
|
||||
{
|
||||
name: "regular failure with existing recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
hasExistingRecovery: true,
|
||||
expectedResult: false,
|
||||
description: "should skip duplicate recovery for regular failure",
|
||||
},
|
||||
{
|
||||
name: "regular failure without existing recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
hasExistingRecovery: false,
|
||||
expectedResult: true,
|
||||
description: "should start new recovery for regular failure",
|
||||
},
|
||||
{
|
||||
name: "OS failure with existing recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
hasExistingRecovery: true,
|
||||
expectedResult: false,
|
||||
description: "should skip duplicate recovery for OS failure",
|
||||
},
|
||||
{
|
||||
name: "OS failure without existing recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
hasExistingRecovery: false,
|
||||
expectedResult: true,
|
||||
description: "should start new recovery for OS failure",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
// Setup existing recovery if needed
|
||||
if tc.hasExistingRecovery {
|
||||
p.recoveryCancelMu.Lock()
|
||||
p.recoveryCancel = func() {} // Mock cancel function
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
result := p.shouldStartRecovery(tc.reason)
|
||||
assert.Equal(t, tc.expectedResult, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createRecoveryContext(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
ctx, cleanup := p.createRecoveryContext()
|
||||
|
||||
// Verify context is created
|
||||
assert.NotNil(t, ctx)
|
||||
assert.NotNil(t, cleanup)
|
||||
|
||||
// Verify recoveryCancel is set
|
||||
p.recoveryCancelMu.Lock()
|
||||
assert.NotNil(t, p.recoveryCancel)
|
||||
p.recoveryCancelMu.Unlock()
|
||||
|
||||
// Test cleanup function
|
||||
cleanup()
|
||||
|
||||
// Verify recoveryCancel is cleared
|
||||
p.recoveryCancelMu.Lock()
|
||||
assert.Nil(t, p.recoveryCancel)
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
func Test_prepareForRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "regular failure",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "network change",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OS failure",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
err := p.prepareForRecovery(tc.reason)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify recoveryRunning is set to true
|
||||
assert.True(t, p.recoveryRunning.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_completeRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
recovered string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "regular failure recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
recovered: "upstream1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "network change recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
recovered: "upstream2",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OS failure recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
recovered: "upstream3",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
err := p.completeRecovery(tc.reason, tc.recovered)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify recoveryRunning is set to false
|
||||
assert.False(t, p.recoveryRunning.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_reinitializeOSResolver(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
err := p.reinitializeOSResolver("Test message")
|
||||
|
||||
// This function should not return an error under normal circumstances
|
||||
// The actual behavior depends on the OS resolver implementation
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_handleRecovery_Integration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "network change recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "regular failure recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OS failure recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
// This is an integration test that exercises the full recovery flow
|
||||
// In a real test environment, you would mock the dependencies
|
||||
// For now, we're just testing that the method doesn't panic
|
||||
// and that the recovery logic flows correctly
|
||||
assert.NotPanics(t, func() {
|
||||
// Test only the preparation phase to avoid actual upstream checking
|
||||
if !p.shouldStartRecovery(tc.reason) {
|
||||
return
|
||||
}
|
||||
|
||||
_, cleanup := p.createRecoveryContext()
|
||||
defer cleanup()
|
||||
|
||||
if err := p.prepareForRecovery(tc.reason); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the actual upstream recovery check for this test
|
||||
// as it requires properly configured upstreams
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newTestProg creates a properly initialized *prog for testing.
|
||||
func newTestProg(t *testing.T) *prog {
|
||||
p := &prog{cfg: testhelper.SampleConfig(t)}
|
||||
p.logger.Store(mainLog.Load())
|
||||
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
||||
return p
|
||||
}
|
||||
|
||||
18
cmd/cli/hostname.go
Normal file
18
cmd/cli/hostname.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import "regexp"
|
||||
|
||||
// validHostname reports whether hostname is a valid hostname.
|
||||
// A valid hostname contains 3 -> 64 characters and conform to RFC1123.
|
||||
// This function validates hostnames to ensure they meet DNS naming standards
|
||||
// and prevents invalid hostnames from being used in DNS configurations
|
||||
func validHostname(hostname string) bool {
|
||||
hostnameLen := len(hostname)
|
||||
if hostnameLen < 3 || hostnameLen > 64 {
|
||||
return false
|
||||
}
|
||||
// RFC1123 regex pattern ensures hostnames follow DNS naming conventions
|
||||
// This prevents issues with DNS resolution and system compatibility
|
||||
validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
|
||||
return validHostnameRfc1123.MatchString(hostname)
|
||||
}
|
||||
35
cmd/cli/hostname_test.go
Normal file
35
cmd/cli/hostname_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_validHostname(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
valid bool
|
||||
}{
|
||||
{"localhost", "localhost", true},
|
||||
{"localdomain", "localhost.localdomain", true},
|
||||
{"localhost6", "localhost6.localdomain6", true},
|
||||
{"ip6", "ip6-localhost", true},
|
||||
{"non-domain", "controld", true},
|
||||
{"domain", "controld.com", true},
|
||||
{"empty", "", false},
|
||||
{"min length", "fo", false},
|
||||
{"max length", strings.Repeat("a", 65), false},
|
||||
{"special char", "foo!", false},
|
||||
{"non-ascii", "fooΩ", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.hostname, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, validHostname(tc.hostname) == tc.valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
172
cmd/cli/http_log.go
Normal file
172
cmd/cli/http_log.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HTTP log server endpoint constants
|
||||
const (
|
||||
httpLogEndpointPing = "/ping"
|
||||
httpLogEndpointLogs = "/logs"
|
||||
httpLogEndpointExit = "/exit"
|
||||
)
|
||||
|
||||
// httpLogClient sends logs to an HTTP server via POST requests.
|
||||
// This replaces the logConn functionality with HTTP-based communication.
|
||||
type httpLogClient struct {
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// newHTTPLogClient creates a new HTTP log client
|
||||
func newHTTPLogClient(sockPath string) *httpLogClient {
|
||||
return &httpLogClient{
|
||||
baseURL: "http://unix",
|
||||
client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Write sends log data to the HTTP server via POST request
|
||||
func (hlc *httpLogClient) Write(b []byte) (int, error) {
|
||||
// Send log data via HTTP POST to /logs endpoint
|
||||
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointLogs, "text/plain", bytes.NewReader(b))
|
||||
if err != nil {
|
||||
// Ignore errors to prevent log pollution, just like the original logConn
|
||||
return len(b), nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Ping tests if the HTTP log server is available
|
||||
func (hlc *httpLogClient) Ping() error {
|
||||
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointPing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close sends exit signal to the HTTP server
|
||||
func (hlc *httpLogClient) Close() error {
|
||||
// Send exit signal via HTTP POST with empty body
|
||||
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLogs retrieves all collected logs from the HTTP server
|
||||
func (hlc *httpLogClient) GetLogs() ([]byte, error) {
|
||||
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNoContent {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// httpLogServer starts an HTTP server listening on unix socket to collect logs from runCmd.
|
||||
func httpLogServer(sockPath string, stopLogCh chan struct{}) error {
|
||||
addr, err := net.ResolveUnixAddr("unix", sockPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid log sock path: %w", err)
|
||||
}
|
||||
|
||||
ln, err := net.ListenUnix("unix", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not listen log socket: %w", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
// Create a log writer to store all logs
|
||||
logWriter := newLogWriter()
|
||||
|
||||
// Use a sync.Once to ensure channel is only closed once
|
||||
var channelClosed sync.Once
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(httpLogEndpointPing, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mux.HandleFunc(httpLogEndpointLogs, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
// POST /logs - Store log data
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store log data in log writer
|
||||
logWriter.Write(body)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
case http.MethodGet:
|
||||
// GET /logs - Retrieve all logs
|
||||
// Get all logs from the log writer
|
||||
logWriter.mu.Lock()
|
||||
logs := logWriter.buf.Bytes()
|
||||
logWriter.mu.Unlock()
|
||||
|
||||
if len(logs) == 0 {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(logs)
|
||||
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc(httpLogEndpointExit, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Close the stop channel to signal completion (only once)
|
||||
channelClosed.Do(func() {
|
||||
close(stopLogCh)
|
||||
})
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
server := &http.Server{Handler: mux}
|
||||
return server.Serve(ln)
|
||||
}
|
||||
747
cmd/cli/http_log_test.go
Normal file
747
cmd/cli/http_log_test.go
Normal file
@@ -0,0 +1,747 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/nettest"
|
||||
)
|
||||
|
||||
func unixDomainSocketPath(t *testing.T) string {
|
||||
t.Helper()
|
||||
sockPath, err := nettest.LocalPath()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
return sockPath
|
||||
}
|
||||
|
||||
func TestHTTPLogServer(t *testing.T) {
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Ping endpoint", func(t *testing.T) {
|
||||
resp, err := client.Get("http://unix" + httpLogEndpointPing)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to ping server: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Ping endpoint wrong method", func(t *testing.T) {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointPing, "text/plain", bytes.NewReader([]byte("test")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send POST to ping: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Log endpoint", func(t *testing.T) {
|
||||
testLog := "test log message"
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(testLog)))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send log: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if log was stored by retrieving it
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), testLog) {
|
||||
t.Errorf("Expected log '%s' not found in stored logs", testLog)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Log endpoint wrong method", func(t *testing.T) {
|
||||
// Test unsupported method (PUT) on /logs endpoint
|
||||
req, err := http.NewRequest("PUT", "http://unix"+httpLogEndpointLogs, bytes.NewReader([]byte("test")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create PUT request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send PUT to logs: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Exit endpoint", func(t *testing.T) {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send exit: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if channel is closed by trying to read from it
|
||||
select {
|
||||
case _, ok := <-stopLogCh:
|
||||
if ok {
|
||||
t.Error("Expected channel to be closed, but it's still open")
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Timeout waiting for channel closure")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Exit endpoint wrong method", func(t *testing.T) {
|
||||
resp, err := client.Get("http://unix" + httpLogEndpointExit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send GET to exit: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Multiple log messages", func(t *testing.T) {
|
||||
logs := []string{"log1", "log2", "log3"}
|
||||
|
||||
for _, log := range logs {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send log '%s': %v", log, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Check if all logs were stored by retrieving them
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(body)
|
||||
for i, expectedLog := range logs {
|
||||
if !strings.Contains(logContent, expectedLog) {
|
||||
t.Errorf("Log %d: expected '%s' not found in stored logs", i, expectedLog)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Large log message", func(t *testing.T) {
|
||||
largeLog := strings.Repeat("a", 1024*10) // 10KB log message
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(largeLog)))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send large log: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if large log was stored by retrieving it
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), largeLog) {
|
||||
t.Error("Large log message was not stored correctly")
|
||||
}
|
||||
})
|
||||
|
||||
// Clean up
|
||||
os.Remove(sockPath)
|
||||
}
|
||||
|
||||
func TestHTTPLogServerInvalidSocketPath(t *testing.T) {
|
||||
// Test with invalid socket path
|
||||
invalidPath := "/invalid/path/that/does/not/exist.sock"
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
err := httpLogServer(invalidPath, stopLogCh)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid socket path")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "could not listen log socket") {
|
||||
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPLogServerSocketInUse(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create the first server
|
||||
stopLogCh1 := make(chan struct{})
|
||||
serverErr1 := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr1 <- httpLogServer(sockPath, stopLogCh1)
|
||||
}()
|
||||
|
||||
// Wait for first server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to create a second server on the same socket
|
||||
stopLogCh2 := make(chan struct{})
|
||||
err := httpLogServer(sockPath, stopLogCh2)
|
||||
if err == nil {
|
||||
t.Error("Expected error when socket is already in use")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "could not listen log socket") {
|
||||
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPLogServerConcurrentRequests(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Send concurrent requests
|
||||
numRequests := 10
|
||||
done := make(chan bool, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(i int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
logMsg := fmt.Sprintf("concurrent log %d", i)
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to send concurrent log %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for request %d, got %d", i, resp.StatusCode)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all requests to complete
|
||||
for i := 0; i < numRequests; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
// Request completed
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Errorf("Timeout waiting for concurrent request %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all logs were stored by retrieving them
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(body)
|
||||
// Verify all logs were stored
|
||||
for i := 0; i < numRequests; i++ {
|
||||
expectedLog := fmt.Sprintf("concurrent log %d", i)
|
||||
if !strings.Contains(logContent, expectedLog) {
|
||||
t.Errorf("Log '%s' was not stored", expectedLog)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPLogServerErrorHandling(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Invalid request body", func(t *testing.T) {
|
||||
// Test with malformed request - this will fail at HTTP level, not server level
|
||||
// The server will return 400 Bad Request for invalid body
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", strings.NewReader(""))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Empty body should still be processed successfully
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHTTPLogServer(b *testing.B) {
|
||||
// Create a temporary socket path
|
||||
tmpDir := b.TempDir()
|
||||
sockPath := filepath.Join(tmpDir, "bench.sock")
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
go func() {
|
||||
httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Benchmark log sending
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logMsg := fmt.Sprintf("benchmark log %d", i)
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to send log: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Clean up
|
||||
os.Remove(sockPath)
|
||||
}
|
||||
|
||||
func TestHTTPLogClient(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP log client
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
t.Run("Ping server", func(t *testing.T) {
|
||||
err := client.Ping()
|
||||
if err != nil {
|
||||
t.Errorf("Ping failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write logs", func(t *testing.T) {
|
||||
testLog := "test log message from client"
|
||||
n, err := client.Write([]byte(testLog))
|
||||
if err != nil {
|
||||
t.Errorf("Write failed: %v", err)
|
||||
}
|
||||
if n != len(testLog) {
|
||||
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
|
||||
}
|
||||
|
||||
// Check if log was stored by retrieving it
|
||||
logs, err := client.GetLogs()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(logs), testLog) {
|
||||
t.Errorf("Expected log '%s' not found in stored logs", testLog)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Close client", func(t *testing.T) {
|
||||
err := client.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
// Check if channel is closed (signaling completion)
|
||||
select {
|
||||
case _, ok := <-stopLogCh:
|
||||
if ok {
|
||||
t.Error("Expected channel to be closed, but it's still open")
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Timeout waiting for channel closure")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPLogClientServerUnavailable(t *testing.T) {
|
||||
// Create client with non-existent socket
|
||||
sockPath := "/non/existent/socket.sock"
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
t.Run("Ping unavailable server", func(t *testing.T) {
|
||||
err := client.Ping()
|
||||
if err == nil {
|
||||
t.Error("Expected ping to fail for unavailable server")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write to unavailable server", func(t *testing.T) {
|
||||
testLog := "test log message"
|
||||
n, err := client.Write([]byte(testLog))
|
||||
if err != nil {
|
||||
t.Errorf("Write should not return error (ignores errors): %v", err)
|
||||
}
|
||||
if n != len(testLog) {
|
||||
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Close unavailable server", func(t *testing.T) {
|
||||
err := client.Close()
|
||||
if err == nil {
|
||||
t.Error("Expected close to fail for unavailable server")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHTTPLogClient(b *testing.B) {
|
||||
// Create a temporary socket path
|
||||
tmpDir := b.TempDir()
|
||||
sockPath := filepath.Join(tmpDir, "bench.sock")
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
go func() {
|
||||
httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP log client
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
// Benchmark client writes
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logMsg := fmt.Sprintf("benchmark write %d", i)
|
||||
client.Write([]byte(logMsg))
|
||||
}
|
||||
|
||||
// Clean up
|
||||
os.Remove(sockPath)
|
||||
}
|
||||
|
||||
func TestHTTPLogServerWithLogWriter(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Store and retrieve logs", func(t *testing.T) {
|
||||
// Send multiple log messages
|
||||
logs := []string{"log message 1", "log message 2", "log message 3"}
|
||||
|
||||
for _, log := range logs {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send log '%s': %v", log, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Retrieve all logs
|
||||
resp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs response: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(body)
|
||||
for _, log := range logs {
|
||||
if !strings.Contains(logContent, log) {
|
||||
t.Errorf("Expected log '%s' not found in retrieved logs", log)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty logs endpoint", func(t *testing.T) {
|
||||
// Create a new server for this test
|
||||
sockPath2 := unixDomainSocketPath(t)
|
||||
stopLogCh2 := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
httpLogServer(sockPath2, stopLogCh2)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client2 := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath2)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client2.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("Expected status 204, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
os.Remove(sockPath2)
|
||||
})
|
||||
|
||||
t.Run("Channel closure on exit", func(t *testing.T) {
|
||||
// Send exit signal
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send exit: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if channel is closed by trying to read from it
|
||||
select {
|
||||
case _, ok := <-stopLogCh:
|
||||
if ok {
|
||||
t.Error("Expected channel to be closed, but it's still open")
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Timeout waiting for channel closure")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPLogClientGetLogs(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
go func() {
|
||||
httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP log client
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
t.Run("Get logs from client", func(t *testing.T) {
|
||||
// Send some logs
|
||||
testLogs := []string{"client log 1", "client log 2", "client log 3"}
|
||||
for _, log := range testLogs {
|
||||
client.Write([]byte(log + "\n"))
|
||||
}
|
||||
|
||||
// Retrieve logs using client method
|
||||
logs, err := client.GetLogs()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(logs)
|
||||
for _, log := range testLogs {
|
||||
if !strings.Contains(logContent, log) {
|
||||
t.Errorf("Expected log '%s' not found in retrieved logs", log)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Get empty logs", func(t *testing.T) {
|
||||
// Create a new client for empty logs test
|
||||
sockPath2 := unixDomainSocketPath(t)
|
||||
stopLogCh2 := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
httpLogServer(sockPath2, stopLogCh2)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client2 := newHTTPLogClient(sockPath2)
|
||||
logs, err := client2.GetLogs()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get empty logs: %v", err)
|
||||
}
|
||||
|
||||
if len(logs) != 0 {
|
||||
t.Errorf("Expected empty logs, got %d bytes", len(logs))
|
||||
}
|
||||
|
||||
os.Remove(sockPath2)
|
||||
})
|
||||
}
|
||||
@@ -1,7 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AppCallback provides hooks for injecting certain functionalities
|
||||
// from mobile platforms to main ctrld cli.
|
||||
// This allows mobile applications to customize behavior without modifying core CLI code
|
||||
type AppCallback struct {
|
||||
HostName func() string
|
||||
LanIp func() string
|
||||
@@ -10,10 +18,94 @@ type AppCallback struct {
|
||||
}
|
||||
|
||||
// AppConfig allows overwriting ctrld cli flags from mobile platforms.
|
||||
// This provides a clean interface for mobile apps to configure ctrld behavior
|
||||
type AppConfig struct {
|
||||
CdUID string
|
||||
HomeDir string
|
||||
UpstreamProto string
|
||||
Verbose int
|
||||
LogPath string
|
||||
CdUID string
|
||||
ProvisionID string
|
||||
CustomHostname string
|
||||
HomeDir string
|
||||
UpstreamProto string
|
||||
Verbose int
|
||||
LogPath string
|
||||
}
|
||||
|
||||
// Network and HTTP configuration constants
|
||||
const (
|
||||
// defaultHTTPTimeout provides reasonable timeout for HTTP operations
|
||||
// This prevents hanging requests while allowing sufficient time for network delays
|
||||
defaultHTTPTimeout = 30 * time.Second
|
||||
|
||||
// defaultMaxRetries provides retry attempts for failed HTTP requests
|
||||
// This improves reliability in unstable network conditions
|
||||
defaultMaxRetries = 3
|
||||
|
||||
// downloadServerIp is the fallback IP for download operations
|
||||
// This ensures downloads work even when DNS resolution fails
|
||||
downloadServerIp = "23.171.240.151"
|
||||
)
|
||||
|
||||
// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback
|
||||
// This ensures reliable HTTP operations by preferring IPv4 and handling timeouts gracefully
|
||||
func httpClientWithFallback(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
// Prefer IPv4 over IPv6
|
||||
// This improves compatibility with networks that have IPv6 issues
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
FallbackDelay: 1 * time.Millisecond, // Very small delay to prefer IPv4
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// doWithRetry performs an HTTP request with retries
|
||||
// This improves reliability by automatically retrying failed requests with exponential backoff
|
||||
func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) {
|
||||
var lastErr error
|
||||
client := httpClientWithFallback(defaultHTTPTimeout)
|
||||
var ipReq *http.Request
|
||||
if ip != "" {
|
||||
ipReq = req.Clone(req.Context())
|
||||
ipReq.Host = ip
|
||||
ipReq.URL.Host = ip
|
||||
}
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Linear backoff reduces server load and improves success rate
|
||||
time.Sleep(time.Second * time.Duration(attempt+1))
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if ipReq != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("Dial to %q failed", req.Host)
|
||||
mainLog.Load().Warn().Msgf("Fallback to direct ip to download prod version: %q", ip)
|
||||
resp, err = client.Do(ipReq)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
mainLog.Load().Debug().Err(err).
|
||||
Str("method", req.Method).
|
||||
Str("url", req.URL.String()).
|
||||
Msgf("HTTP request attempt %d/%d failed", attempt+1, maxRetries)
|
||||
}
|
||||
return nil, fmt.Errorf("failed after %d attempts to %s %s: %v", maxRetries, req.Method, req.URL, lastErr)
|
||||
}
|
||||
|
||||
// Helper for making GET requests with retries
|
||||
// This provides a simplified interface for common GET operations with built-in retry logic
|
||||
func getWithRetry(url string, ip string) (*http.Response, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doWithRetry(req, defaultMaxRetries, ip)
|
||||
}
|
||||
|
||||
407
cmd/cli/log_writer.go
Normal file
407
cmd/cli/log_writer.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Log writer constants for buffer management and log formatting
|
||||
const (
|
||||
// logWriterSize is the default buffer size for log writers
|
||||
// This provides sufficient space for runtime logs without excessive memory usage
|
||||
logWriterSize = 1024 * 1024 * 5 // 5 MB
|
||||
|
||||
// logWriterSmallSize is used for memory-constrained environments
|
||||
// This reduces memory footprint while still maintaining log functionality
|
||||
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
|
||||
|
||||
// logWriterInitialSize is the initial buffer allocation
|
||||
// This provides immediate space for early log entries
|
||||
logWriterInitialSize = 32 * 1024 // 32 KB
|
||||
|
||||
// logWriterSentInterval controls how often logs are sent to external systems
|
||||
// This balances real-time logging with system performance
|
||||
logWriterSentInterval = time.Minute
|
||||
|
||||
// logWriterInitEndMarker marks the end of initialization logs
|
||||
// This helps separate startup logs from runtime logs
|
||||
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
|
||||
|
||||
// logWriterLogEndMarker marks the end of log sections
|
||||
// This provides clear boundaries for log parsing and analysis
|
||||
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
|
||||
)
|
||||
|
||||
// Custom level encoders that handle NOTICE level
|
||||
// Since NOTICE and WARN share the same numeric value (1), we handle them specially
|
||||
// in the encoder to display NOTICE messages with the "NOTICE" prefix.
|
||||
// Note: WARN messages will also display as "NOTICE" because they share the same level value.
|
||||
// This is the intended behavior for visual distinction.
|
||||
|
||||
// noticeLevelEncoder provides custom level encoding for NOTICE level
|
||||
// This ensures NOTICE messages are clearly distinguished from other log levels
|
||||
func noticeLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
|
||||
switch l {
|
||||
case ctrld.NoticeLevel:
|
||||
enc.AppendString("NOTICE")
|
||||
default:
|
||||
zapcore.CapitalLevelEncoder(l, enc)
|
||||
}
|
||||
}
|
||||
|
||||
// noticeColorLevelEncoder provides colored level encoding for NOTICE level
|
||||
// This uses cyan color to make NOTICE messages visually distinct in terminal output
|
||||
func noticeColorLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
|
||||
switch l {
|
||||
case ctrld.NoticeLevel:
|
||||
enc.AppendString("\x1b[36mNOTICE\x1b[0m") // Cyan color for NOTICE
|
||||
default:
|
||||
zapcore.CapitalColorLevelEncoder(l, enc)
|
||||
}
|
||||
}
|
||||
|
||||
// logViewResponse represents the response structure for log viewing requests
|
||||
// This provides a consistent JSON format for log data retrieval
|
||||
type logViewResponse struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// logSentResponse represents the response structure for log sending operations
|
||||
// This includes size information and error details for debugging
|
||||
type logSentResponse struct {
|
||||
Size int64 `json:"size"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// logReader provides read access to log data with size information.
|
||||
//
|
||||
// This struct encapsulates log reading functionality for external consumers,
|
||||
// providing both the log content and metadata about the log size. It supports
|
||||
// reading from both internal log buffers (when no external logging is configured)
|
||||
// and external log files (when logging to file is enabled).
|
||||
//
|
||||
// Fields:
|
||||
// - r: An io.ReadCloser that provides access to the log content
|
||||
// - size: The total size of the log data in bytes
|
||||
//
|
||||
// The logReader is used by the control server to serve log content to clients
|
||||
// and by various CLI commands that need to display or process log data.
|
||||
type logReader struct {
|
||||
r io.ReadCloser
|
||||
size int64
|
||||
}
|
||||
|
||||
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
|
||||
// This provides in-memory log storage for debugging and monitoring purposes
|
||||
type logWriter struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
size int
|
||||
}
|
||||
|
||||
// newLogWriter creates an internal log writer.
|
||||
// This provides the default log writer with standard buffer size
|
||||
func newLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSize)
|
||||
}
|
||||
|
||||
// newSmallLogWriter creates an internal log writer with small buffer size.
|
||||
// This is used in memory-constrained environments or for temporary logging
|
||||
func newSmallLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSmallSize)
|
||||
}
|
||||
|
||||
// newLogWriterWithSize creates an internal log writer with a given buffer size.
|
||||
// This allows customization of log buffer size based on specific requirements
|
||||
func newLogWriterWithSize(size int) *logWriter {
|
||||
lw := &logWriter{size: size}
|
||||
return lw
|
||||
}
|
||||
|
||||
// Write implements io.Writer interface for logWriter
|
||||
// This manages buffer overflow by discarding old data while preserving important markers
|
||||
func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
|
||||
// If writing p causes overflows, discard old data.
|
||||
// This prevents unbounded memory growth while maintaining recent logs
|
||||
if lw.buf.Len()+len(p) > lw.size {
|
||||
buf := lw.buf.Bytes()
|
||||
haveEndMarker := false
|
||||
// If there's init end marker already, preserve the data til the marker.
|
||||
// This ensures initialization logs are always available for debugging
|
||||
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
|
||||
buf = buf[:idx+len(logWriterInitEndMarker)]
|
||||
haveEndMarker = true
|
||||
} else {
|
||||
// Otherwise, preserve the initial size data.
|
||||
buf = buf[:logWriterInitialSize]
|
||||
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
|
||||
buf = buf[:idx]
|
||||
}
|
||||
}
|
||||
lw.buf.Reset()
|
||||
lw.buf.Write(buf)
|
||||
if !haveEndMarker {
|
||||
lw.buf.WriteString(logWriterInitEndMarker) // indicate that the log was truncated.
|
||||
}
|
||||
}
|
||||
// If p is bigger than buffer size, truncate p by half until its size is smaller.
|
||||
for len(p)+lw.buf.Len() > lw.size {
|
||||
p = p[len(p)/2:]
|
||||
}
|
||||
return lw.buf.Write(p)
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func (p *prog) initLogging(backup bool) {
|
||||
logCores := initLoggingWithBackup(backup)
|
||||
|
||||
// Initializing internal logging after global logging.
|
||||
p.initInternalLogging(logCores)
|
||||
p.logger.Store(mainLog.Load())
|
||||
}
|
||||
|
||||
// initInternalLogging performs internal logging if there's no log enabled.
|
||||
func (p *prog) initInternalLogging(externalCores []zapcore.Core) {
|
||||
if !p.needInternalLogging() {
|
||||
return
|
||||
}
|
||||
p.initInternalLogWriterOnce.Do(func() {
|
||||
p.Notice().Msg("Internal logging enabled")
|
||||
p.internalLogWriter = newLogWriter()
|
||||
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
|
||||
p.internalWarnLogWriter = newSmallLogWriter()
|
||||
})
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
|
||||
// Create zap cores for different writers
|
||||
var cores []zapcore.Core
|
||||
cores = append(cores, externalCores...)
|
||||
|
||||
// Add core for internal log writer.
|
||||
// Run the internal logging at debug level, so we could
|
||||
// have enough information for troubleshooting.
|
||||
internalCore := newHumanReadableZapCore(lw, zapcore.DebugLevel)
|
||||
cores = append(cores, internalCore)
|
||||
|
||||
// Add core for internal warn log writer
|
||||
warnCore := newHumanReadableZapCore(wlw, zapcore.WarnLevel)
|
||||
cores = append(cores, warnCore)
|
||||
|
||||
// Create a multi-core logger
|
||||
multiCore := zapcore.NewTee(cores...)
|
||||
logger := zap.New(multiCore)
|
||||
|
||||
mainLog.Store(&ctrld.Logger{Logger: logger})
|
||||
}
|
||||
|
||||
// needInternalLogging reports whether prog needs to run internal logging.
|
||||
func (p *prog) needInternalLogging() bool {
|
||||
// Do not run in non-cd mode.
|
||||
if cdUID == "" {
|
||||
return false
|
||||
}
|
||||
// Do not run if there's already log file.
|
||||
if p.cfg.Service.LogPath != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// logReaderNoColor returns a logReader with ANSI color codes stripped from the log content.
|
||||
//
|
||||
// This method is useful when log content needs to be processed by tools that don't
|
||||
// handle ANSI escape sequences properly, or when storing logs in plain text format.
|
||||
// It internally calls logReader(true) to strip color codes.
|
||||
//
|
||||
// Returns:
|
||||
// - *logReader: A logReader instance with color codes removed, or nil if no logs available
|
||||
// - error: Any error encountered during log reading (e.g., empty logs, file access issues)
|
||||
//
|
||||
// Use cases:
|
||||
// - Log processing pipelines that require plain text
|
||||
// - Storing logs in databases or text files
|
||||
// - Displaying logs in environments that don't support color
|
||||
func (p *prog) logReaderNoColor() (*logReader, error) {
|
||||
return p.logReader(true)
|
||||
}
|
||||
|
||||
// logReaderRaw returns a logReader with ANSI color codes preserved in the log content.
|
||||
//
|
||||
// This method maintains the original formatting of log entries including color codes,
|
||||
// which is useful for displaying logs in terminals that support ANSI colors or when
|
||||
// the original visual formatting needs to be preserved. It internally calls logReader(false).
|
||||
//
|
||||
// Returns:
|
||||
// - *logReader: A logReader instance with color codes preserved, or nil if no logs available
|
||||
// - error: Any error encountered during log reading (e.g., empty logs, file access issues)
|
||||
//
|
||||
// Use cases:
|
||||
// - Terminal-based log viewers that support color
|
||||
// - Interactive debugging sessions
|
||||
// - Preserving original log formatting for display
|
||||
func (p *prog) logReaderRaw() (*logReader, error) {
|
||||
return p.logReader(false)
|
||||
}
|
||||
|
||||
// logReader creates a logReader instance for accessing log content with optional color stripping.
|
||||
//
|
||||
// This is the core method that handles log reading from different sources based on the
|
||||
// current logging configuration. It supports both internal logging (when no external
|
||||
// logging is configured) and external file logging (when logging to file is enabled).
|
||||
//
|
||||
// Behavior:
|
||||
// - Internal logging: Reads from internal log buffers (normal logs + warning logs)
|
||||
// and combines them with appropriate markers for separation
|
||||
// - External logging: Reads directly from the configured log file
|
||||
// - Empty logs: Returns appropriate error messages when no log content is available
|
||||
//
|
||||
// Parameters:
|
||||
// - stripColor: If true, removes ANSI color codes from log content; if false, preserves them
|
||||
//
|
||||
// Returns:
|
||||
// - *logReader: A logReader instance providing access to log content and size metadata
|
||||
// - error: Any error encountered during log reading, including:
|
||||
// - "nil internal log writer" - Internal logging not properly initialized
|
||||
// - "nil internal warn log writer" - Warning log writer not properly initialized
|
||||
// - "internal log is empty" - No content in internal log buffers
|
||||
// - "log file is empty" - External log file exists but contains no data
|
||||
// - File system errors when accessing external log files
|
||||
//
|
||||
// The method handles thread-safe access to internal log buffers and provides
|
||||
// comprehensive error handling for various edge cases.
|
||||
func (p *prog) logReader(stripColor bool) (*logReader, error) {
|
||||
if p.needInternalLogging() {
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
if lw == nil {
|
||||
return nil, errors.New("nil internal log writer")
|
||||
}
|
||||
if wlw == nil {
|
||||
return nil, errors.New("nil internal warn log writer")
|
||||
}
|
||||
|
||||
// Normal log content.
|
||||
lw.mu.Lock()
|
||||
lwReader := newLogReader(&lw.buf, stripColor)
|
||||
lwSize := lw.buf.Len()
|
||||
lw.mu.Unlock()
|
||||
// Warn log content.
|
||||
wlw.mu.Lock()
|
||||
wlwReader := newLogReader(&wlw.buf, stripColor)
|
||||
wlwSize := wlw.buf.Len()
|
||||
wlw.mu.Unlock()
|
||||
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader)
|
||||
lr := &logReader{r: io.NopCloser(reader)}
|
||||
lr.size = int64(lwSize + wlwSize)
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("internal log is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
if p.cfg.Service.LogPath == "" {
|
||||
return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil
|
||||
}
|
||||
f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lr := &logReader{r: f}
|
||||
if st, err := f.Stat(); err == nil {
|
||||
lr.size = st.Size()
|
||||
} else {
|
||||
return nil, fmt.Errorf("f.Stat: %w", err)
|
||||
}
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("log file is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
|
||||
// newHumanReadableZapCore creates a zap core optimized for human-readable log output.
|
||||
//
|
||||
// Features:
|
||||
// - Uses development encoder configuration for enhanced readability
|
||||
// - Console encoding with colored log levels for easy visual scanning
|
||||
// - Millisecond precision timestamps in human-friendly format
|
||||
// - Structured field output with clear key-value pairs
|
||||
// - Ideal for development, debugging, and interactive terminal sessions
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The output writer (e.g., os.Stdout, file, buffer)
|
||||
// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error)
|
||||
//
|
||||
// Returns a zapcore.Core configured for human consumption.
|
||||
func newHumanReadableZapCore(w io.Writer, level zapcore.Level) zapcore.Core {
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoderConfig.TimeKey = "time"
|
||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli)
|
||||
encoderConfig.EncodeLevel = noticeColorLevelEncoder
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
return zapcore.NewCore(encoder, zapcore.AddSync(w), level)
|
||||
}
|
||||
|
||||
// newMachineFriendlyZapCore creates a zap core optimized for machine processing and log aggregation.
|
||||
//
|
||||
// Features:
|
||||
// - Uses production encoder configuration for consistent, parseable output
|
||||
// - Console encoding with non-colored log levels for log parsing tools
|
||||
// - Millisecond precision timestamps in ISO-like format
|
||||
// - Structured field output optimized for log aggregation systems
|
||||
// - Ideal for production environments, log shipping, and automated analysis
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The output writer (e.g., os.Stdout, file, buffer)
|
||||
// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error)
|
||||
//
|
||||
// Returns a zapcore.Core configured for machine consumption and log aggregation.
|
||||
func newMachineFriendlyZapCore(w io.Writer, level zapcore.Level) zapcore.Core {
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.TimeKey = "time"
|
||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli)
|
||||
encoderConfig.EncodeLevel = noticeLevelEncoder
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
return zapcore.NewCore(encoder, zapcore.AddSync(w), level)
|
||||
}
|
||||
|
||||
// ansiRegex is a regular expression to match ANSI color codes.
|
||||
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
|
||||
|
||||
// newLogReader creates a reader for log buffer content with optional ANSI color stripping.
|
||||
//
|
||||
// This function provides flexible log content access by allowing consumers to choose
|
||||
// between raw log data (with ANSI color codes) or stripped content (without color codes).
|
||||
// The color stripping is useful when logs need to be processed by tools that don't
|
||||
// handle ANSI escape sequences properly, or when storing logs in plain text format.
|
||||
//
|
||||
// Parameters:
|
||||
// - buf: The log buffer containing the log data to read
|
||||
// - stripColor: If true, strips ANSI color codes from the log content;
|
||||
// if false, returns raw log content with color codes preserved
|
||||
//
|
||||
// Returns an io.Reader that provides access to the processed log content.
|
||||
func newLogReader(buf *bytes.Buffer, stripColor bool) io.Reader {
|
||||
if stripColor {
|
||||
return strings.NewReader(ansiRegex.ReplaceAllString(buf.String(), ""))
|
||||
}
|
||||
return strings.NewReader(buf.String())
|
||||
}
|
||||
417
cmd/cli/log_writer_test.go
Normal file
417
cmd/cli/log_writer_test.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_logWriter_Write(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
data := strings.Repeat("A", size)
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
newData := "B"
|
||||
halfData := strings.Repeat("A", len(data)/2) + logWriterInitEndMarker
|
||||
lw.Write([]byte(newData))
|
||||
if lw.buf.String() != halfData+newData {
|
||||
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
|
||||
}
|
||||
|
||||
bigData := strings.Repeat("B", 256*1024)
|
||||
expected := halfData + strings.Repeat("B", 16*1024)
|
||||
lw.Write([]byte(bigData))
|
||||
if lw.buf.String() != expected {
|
||||
t.Fatalf("unexpected big buf content: %v", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_ConcurrentWrite(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
n := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
lw.Write([]byte(strings.Repeat("A", i)))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if lw.buf.Len() > lw.size {
|
||||
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_MarkerInitEnd(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
|
||||
paddingSize := 10
|
||||
// Writing half of the size, minus len(end marker) and padding size.
|
||||
dataSize := size/2 - len(logWriterInitEndMarker) - paddingSize
|
||||
data := strings.Repeat("A", dataSize)
|
||||
// Inserting newline for making partial init data
|
||||
data += "\n"
|
||||
// Filling left over buffer to make the log full.
|
||||
// The data length: len(end marker) + padding size - 1 (for newline above) + size/2
|
||||
data += strings.Repeat("A", len(logWriterInitEndMarker)+paddingSize-1+(size/2))
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
lw.Write([]byte("B"))
|
||||
lw.Write([]byte(strings.Repeat("B", 256*1024)))
|
||||
firstIdx := strings.Index(lw.buf.String(), logWriterInitEndMarker)
|
||||
lastIdx := strings.LastIndex(lw.buf.String(), logWriterInitEndMarker)
|
||||
// Check if init end marker present.
|
||||
if firstIdx == -1 || lastIdx == -1 {
|
||||
t.Fatalf("missing init end marker: %s", lw.buf.String())
|
||||
}
|
||||
// Check if init end marker appears only once.
|
||||
if firstIdx != lastIdx {
|
||||
t.Fatalf("log init end marker appears more than once: %s", lw.buf.String())
|
||||
}
|
||||
// Ensure that we have the correct init log data.
|
||||
if !strings.Contains(lw.buf.String(), strings.Repeat("A", dataSize)+logWriterInitEndMarker) {
|
||||
t.Fatalf("unexpected log content: %s", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestNoticeLevel tests that the custom NOTICE level works correctly
|
||||
func TestNoticeLevel(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Create encoder config with custom NOTICE level support
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoderConfig.TimeKey = "time"
|
||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout("15:04:05.000")
|
||||
encoderConfig.EncodeLevel = noticeLevelEncoder
|
||||
|
||||
// Test with NOTICE level
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
core := zapcore.NewCore(encoder, zapcore.AddSync(&buf), ctrld.NoticeLevel)
|
||||
logger := zap.New(core)
|
||||
ctrldLogger := &ctrld.Logger{Logger: logger}
|
||||
|
||||
// Log messages at different levels
|
||||
ctrldLogger.Debug().Msg("This is a DEBUG message")
|
||||
ctrldLogger.Info().Msg("This is an INFO message")
|
||||
ctrldLogger.Notice().Msg("This is a NOTICE message")
|
||||
ctrldLogger.Warn().Msg("This is a WARN message")
|
||||
ctrldLogger.Error().Msg("This is an ERROR message")
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Verify that DEBUG and INFO messages are NOT logged (filtered out)
|
||||
if strings.Contains(output, "DEBUG") {
|
||||
t.Error("DEBUG message should not be logged when level is NOTICE")
|
||||
}
|
||||
if strings.Contains(output, "INFO") {
|
||||
t.Error("INFO message should not be logged when level is NOTICE")
|
||||
}
|
||||
|
||||
// Verify that NOTICE, WARN, and ERROR messages ARE logged
|
||||
if !strings.Contains(output, "NOTICE") {
|
||||
t.Error("NOTICE message should be logged when level is NOTICE")
|
||||
}
|
||||
if !strings.Contains(output, "WARN") {
|
||||
t.Error("WARN message should be logged when level is NOTICE")
|
||||
}
|
||||
if !strings.Contains(output, "ERROR") {
|
||||
t.Error("ERROR message should be logged when level is NOTICE")
|
||||
}
|
||||
|
||||
// Verify the NOTICE message content
|
||||
if !strings.Contains(output, "This is a NOTICE message") {
|
||||
t.Error("NOTICE message content should be present")
|
||||
}
|
||||
|
||||
t.Logf("Log output with NOTICE level:\n%s", output)
|
||||
}
|
||||
|
||||
func TestNewLogReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bufContent string
|
||||
stripColor bool
|
||||
expected string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "empty_buffer_no_color_strip",
|
||||
bufContent: "",
|
||||
stripColor: false,
|
||||
expected: "",
|
||||
description: "Empty buffer should return empty reader",
|
||||
},
|
||||
{
|
||||
name: "empty_buffer_with_color_strip",
|
||||
bufContent: "",
|
||||
stripColor: true,
|
||||
expected: "",
|
||||
description: "Empty buffer with color strip should return empty reader",
|
||||
},
|
||||
{
|
||||
name: "plain_text_no_color_strip",
|
||||
bufContent: "This is plain text without any color codes",
|
||||
stripColor: false,
|
||||
expected: "This is plain text without any color codes",
|
||||
description: "Plain text should be returned as-is when not stripping colors",
|
||||
},
|
||||
{
|
||||
name: "plain_text_with_color_strip",
|
||||
bufContent: "This is plain text without any color codes",
|
||||
stripColor: true,
|
||||
expected: "This is plain text without any color codes",
|
||||
description: "Plain text should be returned as-is when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "text_with_ansi_codes_no_strip",
|
||||
bufContent: "Normal text \x1b[31mred text\x1b[0m normal again",
|
||||
stripColor: false,
|
||||
expected: "Normal text \x1b[31mred text\x1b[0m normal again",
|
||||
description: "ANSI color codes should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "text_with_ansi_codes_with_strip",
|
||||
bufContent: "Normal text \x1b[31mred text\x1b[0m normal again",
|
||||
stripColor: true,
|
||||
expected: "Normal text red text normal again",
|
||||
description: "ANSI color codes should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "multiple_ansi_codes_no_strip",
|
||||
bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
|
||||
stripColor: false,
|
||||
expected: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
|
||||
description: "Multiple ANSI codes should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "multiple_ansi_codes_with_strip",
|
||||
bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
|
||||
stripColor: true,
|
||||
expected: "Bold Green Blue text",
|
||||
description: "Multiple ANSI codes should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "complex_ansi_sequences_no_strip",
|
||||
bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
|
||||
stripColor: false,
|
||||
expected: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
|
||||
description: "Complex ANSI sequences should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "complex_ansi_sequences_with_strip",
|
||||
bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
|
||||
stripColor: true,
|
||||
expected: "Bold red on green Orange",
|
||||
description: "Complex ANSI sequences should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "ansi_codes_with_newlines_no_strip",
|
||||
bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
|
||||
stripColor: false,
|
||||
expected: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
|
||||
description: "ANSI codes with newlines should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "ansi_codes_with_newlines_with_strip",
|
||||
bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
|
||||
stripColor: true,
|
||||
expected: "Line 1\nRed line\nLine 3",
|
||||
description: "ANSI codes with newlines should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "malformed_ansi_codes_no_strip",
|
||||
bufContent: "Text \x1b[invalidm \x1b[0m normal",
|
||||
stripColor: false,
|
||||
expected: "Text \x1b[invalidm \x1b[0m normal",
|
||||
description: "Malformed ANSI codes should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "malformed_ansi_codes_with_strip",
|
||||
bufContent: "Text \x1b[invalidm \x1b[0m normal",
|
||||
stripColor: true,
|
||||
expected: "Text \x1b[invalidm normal",
|
||||
description: "Non-matching ANSI sequences should be preserved when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "large_buffer_no_strip",
|
||||
bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
|
||||
stripColor: false,
|
||||
expected: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
|
||||
description: "Large buffer should handle ANSI codes correctly when not stripping",
|
||||
},
|
||||
{
|
||||
name: "large_buffer_with_strip",
|
||||
bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
|
||||
stripColor: true,
|
||||
expected: strings.Repeat("A", 10000) + strings.Repeat("B", 1000),
|
||||
description: "Large buffer should remove ANSI codes correctly when stripping",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a buffer with the test content
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(tt.bufContent)
|
||||
|
||||
// Create the log reader
|
||||
reader := newLogReader(buf, tt.stripColor)
|
||||
|
||||
// Read all content from the reader
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read from log reader: %v", err)
|
||||
}
|
||||
|
||||
// Verify the content matches expected
|
||||
actual := string(content)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("Expected content: %q, got: %q", tt.expected, actual)
|
||||
t.Logf("Description: %s", tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogReader_ReaderBehavior(t *testing.T) {
|
||||
// Test that the returned reader behaves correctly
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("Test content with \x1b[31mred\x1b[0m text")
|
||||
|
||||
// Test with color stripping
|
||||
reader := newLogReader(buf, true)
|
||||
|
||||
// Test reading in chunks
|
||||
chunk1 := make([]byte, 10)
|
||||
n1, err := reader.Read(chunk1)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("Unexpected error reading first chunk: %v", err)
|
||||
}
|
||||
if n1 != 10 {
|
||||
t.Errorf("Expected to read 10 bytes, got %d", n1)
|
||||
}
|
||||
|
||||
// Test reading remaining content
|
||||
remaining, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read remaining content: %v", err)
|
||||
}
|
||||
|
||||
// Verify total content
|
||||
totalContent := string(chunk1[:n1]) + string(remaining)
|
||||
expected := "Test content with red text"
|
||||
if totalContent != expected {
|
||||
t.Errorf("Expected total content: %q, got: %q", expected, totalContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogReader_ConcurrentAccess(t *testing.T) {
|
||||
// Test concurrent access to the same buffer
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("Concurrent test with \x1b[32mgreen\x1b[0m text")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
results := make(chan string, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reader := newLogReader(buf, true)
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read content: %v", err)
|
||||
return
|
||||
}
|
||||
results <- string(content)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
// Verify all goroutines got the same result
|
||||
expected := "Concurrent test with green text"
|
||||
for result := range results {
|
||||
if result != expected {
|
||||
t.Errorf("Expected: %q, got: %q", expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogReader_ANSIRegexEdgeCases(t *testing.T) {
|
||||
// Test edge cases for ANSI regex matching
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty_escape_sequence",
|
||||
input: "Text \x1b[m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "multiple_semicolons",
|
||||
input: "Text \x1b[1;2;3;4m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "numeric_only",
|
||||
input: "Text \x1b[123m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "mixed_numeric_semicolon",
|
||||
input: "Text \x1b[1;23;456m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "no_closing_bracket",
|
||||
input: "Text \x1b[31 normal",
|
||||
expected: "Text \x1b[31 normal",
|
||||
},
|
||||
{
|
||||
name: "no_opening_bracket",
|
||||
input: "Text 31m normal",
|
||||
expected: "Text 31m normal",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(tt.input)
|
||||
|
||||
reader := newLogReader(buf, true)
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read content: %v", err)
|
||||
}
|
||||
|
||||
actual := string(content)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("Expected: %q, got: %q", tt.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -84,7 +84,7 @@ func (p *prog) detectLoop(msg *dns.Msg) {
|
||||
//
|
||||
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
|
||||
func (p *prog) checkDnsLoop() {
|
||||
mainLog.Load().Debug().Msg("start checking DNS loop")
|
||||
p.Debug().Msg("Start checking DNS loop")
|
||||
upstream := make(map[string]*ctrld.UpstreamConfig)
|
||||
p.loopMu.Lock()
|
||||
for n, uc := range p.cfg.Upstream {
|
||||
@@ -93,7 +93,7 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
// Do not send test query to external upstream.
|
||||
if !canBeLocalUpstream(uc.Domain) {
|
||||
mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n)
|
||||
p.Debug().Msgf("Skipping external: upstream.%s", n)
|
||||
continue
|
||||
}
|
||||
uid := uc.UID()
|
||||
@@ -102,6 +102,7 @@ func (p *prog) checkDnsLoop() {
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
for uid := range p.loop {
|
||||
msg := loopTestMsg(uid)
|
||||
uc := upstream[uid]
|
||||
@@ -109,16 +110,16 @@ func (p *prog) checkDnsLoop() {
|
||||
if uc == nil {
|
||||
continue
|
||||
}
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
resolver, err := ctrld.NewResolver(loggerCtx, uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
p.Warn().Err(err).Msgf("Could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
continue
|
||||
}
|
||||
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
p.Warn().Err(err).Msgf("Could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msg("end checking DNS loop")
|
||||
p.Debug().Msg("End checking DNS loop")
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
@@ -137,7 +138,7 @@ func (p *prog) checkDnsLoopTicker(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// loopTestMsg generates DNS message for checking loop.
|
||||
// loopTestMsg creates a DNS test message for loop detection
|
||||
func loopTestMsg(uid string) *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)
|
||||
|
||||
156
cmd/cli/main.go
156
cmd/cli/main.go
@@ -5,14 +5,16 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/rs/zerolog"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Global variables for CLI configuration and state management
|
||||
// These are used across multiple commands and need to persist throughout the application lifecycle
|
||||
var (
|
||||
configPath string
|
||||
configBase64 string
|
||||
@@ -29,6 +31,7 @@ var (
|
||||
silent bool
|
||||
cdUID string
|
||||
cdOrg string
|
||||
customHostname string
|
||||
cdDev bool
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
@@ -38,38 +41,46 @@ var (
|
||||
skipSelfChecks bool
|
||||
cleanup bool
|
||||
startOnly bool
|
||||
rfc1918 bool
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
noConfigStart bool
|
||||
mainLog atomic.Pointer[ctrld.Logger]
|
||||
consoleWriter zapcore.Core
|
||||
consoleWriterLevel zapcore.Level
|
||||
noConfigStart bool
|
||||
)
|
||||
|
||||
// Flag name constants for consistent reference across the codebase
|
||||
// Using constants prevents typos and makes refactoring easier
|
||||
const (
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
nextdnsFlagName = "nextdns"
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
customHostnameFlagName = "custom-hostname"
|
||||
nextdnsFlagName = "nextdns"
|
||||
)
|
||||
|
||||
// init initializes the default logger before any CLI commands are executed
|
||||
// This ensures logging is available even during early initialization phases
|
||||
func init() {
|
||||
l := zerolog.New(io.Discard)
|
||||
mainLog.Store(&l)
|
||||
l := zap.NewNop()
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
}
|
||||
|
||||
// Main is the entry point for the CLI application
|
||||
// It initializes configuration, sets up the CLI structure, and executes the root command
|
||||
func Main() {
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
initCLI()
|
||||
rootCmd := initCLI()
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeLogFilePath converts relative log file paths to absolute paths
|
||||
// This ensures log files are created in predictable locations regardless of working directory
|
||||
func normalizeLogFilePath(logFilePath string) string {
|
||||
// In cleanup mode, we always want the full log file path.
|
||||
if !cleanup {
|
||||
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
||||
return logFilePath
|
||||
}
|
||||
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
||||
return logFilePath
|
||||
}
|
||||
if homedir != "" {
|
||||
return filepath.Join(homedir, logFilePath)
|
||||
@@ -82,29 +93,36 @@ func normalizeLogFilePath(logFilePath string) string {
|
||||
}
|
||||
|
||||
// initConsoleLogging initializes console logging, then storing to mainLog.
|
||||
// This sets up human-readable logging output for interactive use
|
||||
func initConsoleLogging() {
|
||||
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.StampMilli
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
mainLog.Store(&l)
|
||||
consoleWriterLevel = ctrld.NoticeLevel
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
// For silent mode, use a no-op logger to suppress all output
|
||||
l := zap.NewNop()
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
case verbose == 1:
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
// Info level provides basic operational information
|
||||
consoleWriterLevel = zapcore.InfoLevel
|
||||
case verbose > 1:
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
// Debug level provides detailed diagnostic information
|
||||
consoleWriterLevel = zapcore.DebugLevel
|
||||
}
|
||||
consoleWriter = newHumanReadableZapCore(os.Stdout, consoleWriterLevel)
|
||||
l := zap.New(consoleWriter)
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func initLogging() {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(true)
|
||||
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
|
||||
// to be used for all interactive commands.
|
||||
//
|
||||
// Current log file config will also be ignored.
|
||||
// This prevents log file conflicts during interactive command execution
|
||||
func initInteractiveLogging() {
|
||||
old := cfg.Service.LogPath
|
||||
cfg.Service.LogPath = ""
|
||||
initLoggingWithBackup(false)
|
||||
cfg.Service.LogPath = old
|
||||
}
|
||||
|
||||
// initLoggingWithBackup initializes log setup base on current config.
|
||||
@@ -113,67 +131,101 @@ func initLogging() {
|
||||
// This is only used in runCmd for special handling in case of logging config
|
||||
// change in cd mode. Without special reason, the caller should use initLogging
|
||||
// wrapper instead of calling this function directly.
|
||||
func initLoggingWithBackup(doBackup bool) {
|
||||
writers := []io.Writer{io.Discard}
|
||||
func initLoggingWithBackup(doBackup bool) []zapcore.Core {
|
||||
var writers []io.Writer
|
||||
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
|
||||
// Create parent directory if necessary.
|
||||
// This ensures log files can be created even if the directory doesn't exist
|
||||
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
|
||||
mainLog.Load().Error().Msgf("failed to create log path: %v", err)
|
||||
mainLog.Load().Error().Msgf("Failed to create log path: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Default open log file in append mode.
|
||||
// This preserves existing log entries across restarts
|
||||
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
|
||||
if doBackup {
|
||||
// Backup old log file with .1 suffix.
|
||||
// This prevents log file corruption during rotation
|
||||
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
|
||||
mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
|
||||
mainLog.Load().Error().Msgf("Could not backup old log file: %v", err)
|
||||
} else {
|
||||
// Backup was created, set flags for truncating old log file.
|
||||
// This ensures a clean start for the new log file
|
||||
flags = os.O_CREATE | os.O_RDWR
|
||||
}
|
||||
}
|
||||
logFile, err := openLogFile(logFilePath, flags)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("failed to create log file: %v", err)
|
||||
mainLog.Load().Error().Msgf("Failed to create log file: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
writers = append(writers, logFile)
|
||||
}
|
||||
writers = append(writers, consoleWriter)
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
// TODO: find a better way.
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
// Create zap cores for different writers
|
||||
// Multiple cores allow logging to both console and file simultaneously
|
||||
var cores []zapcore.Core
|
||||
cores = append(cores, consoleWriter)
|
||||
|
||||
// Determine log level based on verbosity and configuration
|
||||
// This provides flexible logging control for different use cases
|
||||
logLevel := cfg.Service.LogLevel
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
return
|
||||
// For silent mode, use a no-op logger to suppress all output
|
||||
l := zap.NewNop()
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
return cores
|
||||
case verbose == 1:
|
||||
logLevel = "info"
|
||||
case verbose > 1:
|
||||
logLevel = "debug"
|
||||
}
|
||||
if logLevel == "" {
|
||||
return
|
||||
|
||||
// Parse log level string to zapcore.Level
|
||||
// This provides human-readable log level configuration
|
||||
var level zapcore.Level
|
||||
switch logLevel {
|
||||
case "debug":
|
||||
level = zapcore.DebugLevel
|
||||
case "info":
|
||||
level = zapcore.InfoLevel
|
||||
case "notice":
|
||||
level = ctrld.NoticeLevel
|
||||
case "warn":
|
||||
level = zapcore.WarnLevel
|
||||
case "error":
|
||||
level = zapcore.ErrorLevel
|
||||
default:
|
||||
level = zapcore.InfoLevel // default level
|
||||
}
|
||||
level, err := zerolog.ParseLevel(logLevel)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set log level")
|
||||
return
|
||||
|
||||
consoleWriter.Enabled(level)
|
||||
// Add cores for all writers
|
||||
// This enables multi-destination logging (console + file)
|
||||
for _, writer := range writers {
|
||||
core := newMachineFriendlyZapCore(writer, level)
|
||||
cores = append(cores, core)
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
|
||||
// Create a multi-core logger
|
||||
// This allows simultaneous logging to multiple destinations
|
||||
multiCore := zapcore.NewTee(cores...)
|
||||
logger := zap.New(multiCore)
|
||||
mainLog.Store(&ctrld.Logger{Logger: logger})
|
||||
|
||||
return cores
|
||||
}
|
||||
|
||||
// initCache initializes DNS cache configuration
|
||||
// This improves performance by caching frequently requested DNS responses
|
||||
func initCache() {
|
||||
if !cfg.Service.CacheEnable {
|
||||
return
|
||||
}
|
||||
if cfg.Service.CacheSize == 0 {
|
||||
// Default cache size provides good balance between memory usage and performance
|
||||
cfg.Service.CacheSize = 4096
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,13 +5,28 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
var logOutput strings.Builder
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
l := zerolog.New(&logOutput)
|
||||
mainLog.Store(&l)
|
||||
// Create a custom writer that writes to logOutput
|
||||
writer := zapcore.AddSync(&logOutput)
|
||||
|
||||
// Create zap encoder
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
|
||||
// Create core that writes to our string builder
|
||||
core := zapcore.NewCore(encoder, writer, zap.DebugLevel)
|
||||
|
||||
// Create logger
|
||||
l := zap.New(core)
|
||||
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
)
|
||||
|
||||
// metricsServer represents a server to expose Prometheus metrics via HTTP.
|
||||
// This provides monitoring and observability for the DNS proxy service
|
||||
type metricsServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
@@ -24,6 +25,7 @@ type metricsServer struct {
|
||||
}
|
||||
|
||||
// newMetricsServer returns new metrics server.
|
||||
// This initializes the HTTP server for exposing Prometheus metrics
|
||||
func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, error) {
|
||||
mux := http.NewServeMux()
|
||||
ms := &metricsServer{
|
||||
@@ -37,11 +39,13 @@ func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, er
|
||||
}
|
||||
|
||||
// register adds handlers for given pattern.
|
||||
// This provides a clean interface for adding HTTP endpoints to the metrics server
|
||||
func (ms *metricsServer) register(pattern string, handler http.Handler) {
|
||||
ms.mux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
// registerMetricsServerHandler adds handlers for metrics server.
|
||||
// This sets up both Prometheus format and JSON format endpoints for metrics
|
||||
func (ms *metricsServer) registerMetricsServerHandler() {
|
||||
ms.register("/metrics", promhttp.HandlerFor(
|
||||
ms.reg,
|
||||
@@ -74,6 +78,7 @@ func (ms *metricsServer) registerMetricsServerHandler() {
|
||||
}
|
||||
|
||||
// start runs the metricsServer.
|
||||
// This starts the HTTP server for metrics exposure
|
||||
func (ms *metricsServer) start() error {
|
||||
listener, err := net.Listen("tcp", ms.addr)
|
||||
if err != nil {
|
||||
@@ -85,6 +90,7 @@ func (ms *metricsServer) start() error {
|
||||
}
|
||||
|
||||
// stop shutdowns the metricsServer within 2 seconds timeout.
|
||||
// This ensures graceful shutdown of the metrics server
|
||||
func (ms *metricsServer) stop() error {
|
||||
if !ms.started {
|
||||
return nil
|
||||
@@ -95,6 +101,7 @@ func (ms *metricsServer) stop() error {
|
||||
}
|
||||
|
||||
// runMetricsServer initializes metrics stats and runs the metrics server if enabled.
|
||||
// This sets up the complete metrics infrastructure including Prometheus collectors
|
||||
func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
if !p.metricsEnabled() {
|
||||
return
|
||||
@@ -115,7 +122,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
addr := p.cfg.Service.MetricsListener
|
||||
ms, err := newMetricsServer(addr, reg)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create new metrics server")
|
||||
mainLog.Load().Warn().Err(err).Msg("Could not create new metrics server")
|
||||
return
|
||||
}
|
||||
// Only start listener address if defined.
|
||||
@@ -130,9 +137,9 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
statsVersion.WithLabelValues(commit, runtime.Version(), curVersion()).Inc()
|
||||
reg.MustRegister(statsTimeStart)
|
||||
statsTimeStart.Set(float64(time.Now().Unix()))
|
||||
mainLog.Load().Debug().Msgf("starting metrics server on: %s", addr)
|
||||
mainLog.Load().Debug().Msgf("Starting metrics server on: %s", addr)
|
||||
if err := ms.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start metrics server")
|
||||
mainLog.Load().Warn().Err(err).Msg("Could not start metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -144,7 +151,7 @@ func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
}
|
||||
|
||||
if err := ms.stop(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not stop metrics server")
|
||||
mainLog.Load().Warn().Err(err).Msg("Could not stop metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package cli
|
||||
|
||||
import "strings"
|
||||
|
||||
// Copied from https://gist.github.com/Ultraporing/fe52981f678be6831f747c206a4861cb
|
||||
|
||||
// Mac Address parts to look for, and identify non-physical devices. There may be more, update me!
|
||||
var macAddrPartsToFilter = []string{
|
||||
"00:03:FF", // Microsoft Hyper-V, Virtual Server, Virtual PC
|
||||
"0A:00:27", // VirtualBox
|
||||
"00:00:00:00:00", // Teredo Tunneling Pseudo-Interface
|
||||
"00:50:56", // VMware ESX 3, Server, Workstation, Player
|
||||
"00:1C:14", // VMware ESX 3, Server, Workstation, Player
|
||||
"00:0C:29", // VMware ESX 3, Server, Workstation, Player
|
||||
"00:05:69", // VMware ESX 3, Server, Workstation, Player
|
||||
"00:1C:42", // Microsoft Hyper-V, Virtual Server, Virtual PC
|
||||
"00:0F:4B", // Virtual Iron 4
|
||||
"00:16:3E", // Red Hat Xen, Oracle VM, XenSource, Novell Xen
|
||||
"08:00:27", // Sun xVM VirtualBox
|
||||
"7A:79", // Hamachi
|
||||
}
|
||||
|
||||
// Filters the possible physical interface address by comparing it to known popular VM Software addresses
|
||||
// and Teredo Tunneling Pseudo-Interface.
|
||||
//
|
||||
//lint:ignore U1000 use in net_windows.go
|
||||
func isPhysicalInterface(addr string) bool {
|
||||
for _, macPart := range macAddrPartsToFilter {
|
||||
if strings.HasPrefix(strings.ToLower(addr), strings.ToLower(macPart)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
patched := false
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
patched = true
|
||||
iface.Name = name
|
||||
mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface")
|
||||
}
|
||||
return nil
|
||||
return patched, nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
@@ -48,27 +49,3 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
b, err := exec.Command("networksetup", "-listallhardwareports").Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return parseListAllHardwarePorts(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports"
|
||||
// and returns map presents all hardware ports.
|
||||
func parseListAllHardwarePorts(r io.Reader) map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
after, ok := strings.CutPrefix(line, "Device: ")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
m[after] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
17
cmd/cli/net_linux.go
Normal file
17
cmd/cli/net_linux.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// patchNetIfaceName patches network interface names on Linux
|
||||
// This is a no-op on Linux as interface names don't need special handling
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// Only non-virtual interfaces are considered valid.
|
||||
// This prevents DNS configuration on virtual interfaces like docker, veth, etc.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
//go:build !darwin && !windows
|
||||
//go:build !darwin && !windows && !linux
|
||||
|
||||
package cli
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error { return nil }
|
||||
// patchNetIfaceName patches network interface names on non-Linux/Darwin platforms
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
// validInterface checks if an interface is valid on non-Linux/Darwin platforms
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
|
||||
|
||||
func validInterfacesMap() map[string]struct{} { return nil }
|
||||
|
||||
@@ -4,20 +4,13 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
return nil
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// On Windows, only physical interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
if iface == nil {
|
||||
return false
|
||||
}
|
||||
if isPhysicalInterface(iface.HardwareAddr.String()) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func validInterfacesMap() map[string]struct{} { return nil }
|
||||
|
||||
47
cmd/cli/net_windows_test.go
Normal file
47
cmd/cli/net_windows_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_validInterfaces(t *testing.T) {
|
||||
verbose = 3
|
||||
initConsoleLogging()
|
||||
start := time.Now()
|
||||
im := ctrld.ValidInterfaces(ctrld.LoggerCtx(context.Background(), mainLog.Load()))
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
ifaces := slices.Collect(maps.Keys(im))
|
||||
|
||||
start = time.Now()
|
||||
ifacesPowershell := validInterfacesPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(ifaces)
|
||||
slices.Sort(ifacesPowershell)
|
||||
if !slices.Equal(ifaces, ifacesPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces)
|
||||
}
|
||||
}
|
||||
|
||||
func validInterfacesPowershell() []string {
|
||||
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var res []string
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
ifaceName := strings.TrimSpace(scanner.Text())
|
||||
res = append(res, ifaceName)
|
||||
}
|
||||
return res
|
||||
}
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState(ctx context.Context) {
|
||||
@@ -12,7 +14,7 @@ func (p *prog) watchLinkState(ctx context.Context) {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
if err := netlink.LinkSubscribe(ch, done); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not subscribe link")
|
||||
p.Warn().Err(err).Msg("Could not subscribe link")
|
||||
return
|
||||
}
|
||||
for {
|
||||
@@ -24,9 +26,9 @@ func (p *prog) watchLinkState(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Load().Debug().Msgf("link state changed, re-bootstrapping")
|
||||
p.Debug().Msgf("Link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,66 +23,67 @@ systemd-resolved=false
|
||||
var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename)
|
||||
|
||||
// hasNetworkManager reports whether NetworkManager executable found.
|
||||
// hasNetworkManager checks if NetworkManager is available on the system
|
||||
func hasNetworkManager() bool {
|
||||
exe, _ := exec.LookPath("NetworkManager")
|
||||
return exe != ""
|
||||
}
|
||||
|
||||
func setupNetworkManager() error {
|
||||
func (p *prog) setupNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent {
|
||||
mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
p.Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
return nil
|
||||
}
|
||||
err := os.WriteFile(networkManagerCtrldConfFile, []byte(nmCtrldConfContent), os.FileMode(0644))
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
p.Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not write NetworkManager ctrld config file")
|
||||
p.Debug().Err(err).Msg("Could not write NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
reloadNetworkManager()
|
||||
mainLog.Load().Debug().Msg("setup NetworkManager done")
|
||||
p.reloadNetworkManager()
|
||||
p.Debug().Msg("Setup NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
func (p *prog) restoreNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
err := os.Remove(networkManagerCtrldConfFile)
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
p.Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not remove NetworkManager ctrld config file")
|
||||
p.Debug().Err(err).Msg("Could not remove NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
reloadNetworkManager()
|
||||
mainLog.Load().Debug().Msg("restore NetworkManager done")
|
||||
p.reloadNetworkManager()
|
||||
p.Debug().Msg("Restore NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func reloadNetworkManager() {
|
||||
func (p *prog) reloadNetworkManager() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
conn, err := dbus.NewSystemConnectionContext(ctx)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not create new system connection")
|
||||
p.Error().Err(err).Msg("Could not create new system connection")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
waitCh := make(chan string)
|
||||
if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager")
|
||||
p.Debug().Err(err).Msg("Could not reload NetworkManager")
|
||||
return
|
||||
}
|
||||
<-waitCh
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
package cli
|
||||
|
||||
func setupNetworkManager() error {
|
||||
reloadNetworkManager()
|
||||
func (p *prog) setupNetworkManager() error {
|
||||
p.reloadNetworkManager()
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
reloadNetworkManager()
|
||||
func (p *prog) restoreNetworkManager() error {
|
||||
p.reloadNetworkManager()
|
||||
return nil
|
||||
}
|
||||
|
||||
func reloadNetworkManager() {}
|
||||
func (p *prog) reloadNetworkManager() {}
|
||||
|
||||
@@ -8,11 +8,12 @@ import (
|
||||
|
||||
const nextdnsURL = "https://dns.nextdns.io"
|
||||
|
||||
// generateNextDNSConfig generates NextDNS configuration for the given UID
|
||||
func generateNextDNSConfig(uid string) {
|
||||
if uid == "" {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver")
|
||||
mainLog.Load().Info().Msg("Generating ctrld config for NextDNS resolver")
|
||||
cfg = ctrld.Config{
|
||||
Listener: map[string]*ctrld.ListenerConfig{
|
||||
"0": {
|
||||
|
||||
5
cmd/cli/nocgo.go
Normal file
5
cmd/cli/nocgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = false
|
||||
@@ -8,26 +8,31 @@ import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// allocate loopback ip
|
||||
// allocateIP allocates an IP address on the specified interface
|
||||
// sudo ifconfig lo0 alias 127.0.0.2 up
|
||||
func allocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", "alias", ip, "up")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("allocateIP failed")
|
||||
mainLog.Load().Error().Err(err).Msg("AllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// deAllocateIP deallocates an IP address from the specified interface
|
||||
func deAllocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", "-alias", ip)
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -47,12 +52,19 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
|
||||
// TODO(cuonglm): use system API
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
|
||||
|
||||
// Note that networksetup won't modify search domains settings,
|
||||
// This assignment is just a placeholder to silent linter.
|
||||
_ = searchDomains
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name}
|
||||
args = append(args, nameservers...)
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -70,21 +82,30 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
|
||||
// TODO(cuonglm): use system API
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
if err := setDNS(iface, ns); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
|
||||
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name, "empty"}
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if ns := ctrld.SavedStaticNameservers(iface); len(ns) > 0 {
|
||||
err = setDNS(iface, ns)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// currentDNS returns the current DNS servers for the specified interface
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
return ctrld.CurrentNameserversFromResolvconf()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
|
||||
@@ -5,27 +5,36 @@ import (
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// allocate loopback ip
|
||||
// allocateIP allocates an IP address on the specified interface
|
||||
// sudo ifconfig lo0 127.0.0.53 alias
|
||||
func allocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", ip, "alias")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("allocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// deAllocateIP deallocates an IP address from the specified interface
|
||||
func deAllocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", ip, "-alias")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -36,9 +45,11 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
|
||||
// set the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
r, err := dns.NewOSConfigurator(logf, iface.Name)
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
|
||||
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -47,10 +58,22 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
ns = append(ns, netip.MustParseAddr(nameserver))
|
||||
}
|
||||
|
||||
if err := r.SetDNS(dns.OSConfig{Nameservers: ns}); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to set DNS")
|
||||
osConfig := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list")
|
||||
}
|
||||
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to set DNS")
|
||||
return err
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -59,22 +82,34 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
// resetDNS resets DNS servers for the specified interface
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
r, err := dns.NewOSConfigurator(logf, iface.Name)
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
|
||||
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.Close(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to rollback DNS setting")
|
||||
return err
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// currentDNS returns the current DNS servers for the specified interface
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
return ctrld.CurrentNameserversFromResolvconf()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
|
||||
@@ -17,32 +17,40 @@ import (
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/client6"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
|
||||
type getDNS func(iface string) []string
|
||||
|
||||
// allocate loopback ip
|
||||
// sudo ip a add 127.0.0.2/24 dev lo
|
||||
func allocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
|
||||
cmd := exec.Command("ip", "a", "add", ip+"/24", "dev", "lo")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("allocateIP failed: %s", string(out))
|
||||
mainLog.Load().Error().Err(err).Msgf("AllocateIP failed: %s", string(out))
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func deAllocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
|
||||
cmd := exec.Command("ip", "a", "del", ip+"/24", "dev", "lo")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -54,9 +62,11 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
}
|
||||
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
r, err := dns.NewOSConfigurator(logf, iface.Name)
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
|
||||
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to create dns os configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -69,35 +79,31 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list")
|
||||
}
|
||||
trySystemdResolve := false
|
||||
for i := 0; i < maxSetDNSAttempts; i++ {
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
break
|
||||
}
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
goto systemdResolve
|
||||
}
|
||||
if useSystemdResolved {
|
||||
if out, err := exec.Command("systemctl", "restart", "systemd-resolved").CombinedOutput(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not restart systemd-resolved: %s", string(out))
|
||||
}
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
systemdResolve:
|
||||
if trySystemdResolve {
|
||||
// Stop systemd-networkd and retry setting DNS.
|
||||
if out, err := exec.Command("systemctl", "stop", "systemd-networkd").CombinedOutput(); err != nil {
|
||||
@@ -117,8 +123,10 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -128,6 +136,8 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
}
|
||||
|
||||
func resetDNS(iface *net.Interface) (err error) {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
return
|
||||
@@ -136,10 +146,10 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
if exe, _ := exec.LookPath("/lib/systemd/systemd-networkd"); exe != "" {
|
||||
_ = exec.Command("systemctl", "start", "systemd-networkd").Run()
|
||||
}
|
||||
if r, oerr := dns.NewOSConfigurator(logf, iface.Name); oerr == nil {
|
||||
if r, oerr := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name); oerr == nil {
|
||||
_ = r.SetDNS(dns.OSConfig{})
|
||||
if err := r.Close(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to rollback dns setting")
|
||||
return
|
||||
}
|
||||
err = nil
|
||||
@@ -167,17 +177,18 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
|
||||
// TODO(cuonglm): handle DHCPv6 properly.
|
||||
mainLog.Load().Debug().Msg("Checking for ipv6 availability")
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
c := client6.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
if err != nil && !errAddrInUse(err) {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not exchange DHCPv6")
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not exchange dhcpv6")
|
||||
}
|
||||
for _, packet := range conversation {
|
||||
if packet.Type() == dhcpv6.MessageTypeReply {
|
||||
msg, err := packet.GetInnerMessage()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not get inner DHCPv6 message")
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not get inner dhcpv6 message")
|
||||
return nil
|
||||
}
|
||||
nameservers := msg.Options.DNS()
|
||||
@@ -186,6 +197,8 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("IPv6 is not available")
|
||||
}
|
||||
|
||||
return ignoringEINTR(func() error {
|
||||
@@ -193,8 +206,15 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
})
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
|
||||
resolvconfFunc := func(_ string) []string { return ctrld.CurrentNameserversFromResolvconf() }
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} {
|
||||
if ns := fn(iface.Name); len(ns) > 0 {
|
||||
return ns
|
||||
}
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
|
||||
package cli
|
||||
|
||||
// TODO(cuonglm): implement.
|
||||
// allocateIP allocates an IP address on the specified interface
|
||||
func allocateIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(cuonglm): implement.
|
||||
// deAllocateIP deallocates an IP address from the specified interface
|
||||
func deAllocateIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,25 +4,21 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
const (
|
||||
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
var (
|
||||
setDNSOnce sync.Once
|
||||
resetDNSOnce sync.Once
|
||||
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
@@ -30,44 +26,46 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
func setDnsPowershellCmd(iface *net.Interface, nameservers []string) string {
|
||||
nss := make([]string, 0, len(nameservers))
|
||||
for _, ns := range nameservers {
|
||||
nss = append(nss, strconv.Quote(ns))
|
||||
}
|
||||
return fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ServerAddresses (%s)", iface.Index, strings.Join(nss, ","))
|
||||
}
|
||||
|
||||
// setDNS sets the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
if len(nameservers) == 0 {
|
||||
return errors.New("empty DNS nameservers")
|
||||
}
|
||||
setDNSOnce.Do(func() {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
oldForwardersContent, _ := os.ReadFile(file)
|
||||
hasLocalIPv6Listener := needLocalIPv6Listener()
|
||||
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
|
||||
if !hasLocalIPv6Listener {
|
||||
return false
|
||||
}
|
||||
return s == "::1"
|
||||
})
|
||||
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
||||
}
|
||||
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
||||
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("setDNS: %w", err)
|
||||
}
|
||||
var (
|
||||
serversV4 []netip.Addr
|
||||
serversV6 []netip.Addr
|
||||
)
|
||||
for _, ns := range nameservers {
|
||||
if addr, err := netip.ParseAddr(ns); err == nil {
|
||||
if addr.Is4() {
|
||||
serversV4 = append(serversV4, addr)
|
||||
} else {
|
||||
serversV6 = append(serversV6, addr)
|
||||
}
|
||||
}
|
||||
})
|
||||
out, err := powershell(setDnsPowershellCmd(iface, nameservers))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
}
|
||||
|
||||
// Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function.
|
||||
// searchDomains is still implemented for Windows just in case Windows API changes in future versions.
|
||||
_ = searchDomains
|
||||
|
||||
if len(serversV4) == 0 && len(serversV6) == 0 {
|
||||
return errors.New("invalid DNS nameservers")
|
||||
}
|
||||
if len(serversV4) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv4: %w", err)
|
||||
}
|
||||
}
|
||||
if len(serversV6) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv6: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -77,34 +75,26 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
// TODO(cuonglm): should we use system API?
|
||||
// resetDNS resets DNS servers for the specified interface
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
resetDNSOnce.Do(func() {
|
||||
// See corresponding comment in setDNS.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
||||
return
|
||||
}
|
||||
nameservers := strings.Split(string(content), ",")
|
||||
if err := removeDnsServerForwarders(nameservers); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Restoring DHCP settings.
|
||||
cmd := fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ResetServerAddresses", iface.Index)
|
||||
out, err := powershell(cmd)
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
return fmt.Errorf("resetDNS: %w", err)
|
||||
}
|
||||
// Restoring DHCP settings.
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv4: %w", err)
|
||||
}
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If there's static DNS saved, restoring it.
|
||||
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if nss := ctrld.SavedStaticNameservers(iface); len(nss) > 0 {
|
||||
v4ns := make([]string, 0, 2)
|
||||
v6ns := make([]string, 0, 2)
|
||||
for _, ns := range nss {
|
||||
@@ -115,27 +105,48 @@ func resetDNS(iface *net.Interface) error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, ns := range [][]string{v4ns, v6ns} {
|
||||
if len(ns) == 0 {
|
||||
continue
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoreDNS: %w", err)
|
||||
}
|
||||
|
||||
if len(v4ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
|
||||
if err := setDNS(iface, v4ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4): %w", err)
|
||||
}
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
return err
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv4 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(v6ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
|
||||
if err := setDNS(iface, v6ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv6 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
// currentDNS returns the current DNS servers for the specified interface
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get interface LUID")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to get interface LUID")
|
||||
return nil
|
||||
}
|
||||
nameservers, err := luid.DNS()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get interface DNS")
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to get interface DNS")
|
||||
return nil
|
||||
}
|
||||
ns := make([]string, 0, len(nameservers))
|
||||
@@ -145,73 +156,65 @@ func currentDNS(iface *net.Interface) []string {
|
||||
return ns
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys
|
||||
// like "NameServer" and "ProfileNameServer".
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err)
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("fallback luid.GUID: %w", err)
|
||||
}
|
||||
|
||||
var ns []string
|
||||
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
|
||||
keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat}
|
||||
for _, path := range keyPaths {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer k.Close()
|
||||
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
|
||||
value, _, err := k.GetStringValue(keyName)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Error reading %s registry key", keyName)
|
||||
continue
|
||||
}
|
||||
if len(value) > 0 {
|
||||
mainLog.Load().Debug().Msgf("Found static DNS for interface %q: %s", iface.Name, value)
|
||||
parsed := parseDNSServers(value)
|
||||
for _, pns := range parsed {
|
||||
if !slices.Contains(ns, pns) {
|
||||
ns = append(ns, pns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Debug().Msgf("No static DNS values found for interface %q", iface.Name)
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
|
||||
// and also removing old forwarders if provided.
|
||||
func addDnsServerForwarders(nameservers, old []string) error {
|
||||
newForwardersMap := make(map[string]struct{})
|
||||
newForwarders := make([]string, len(nameservers))
|
||||
for i := range nameservers {
|
||||
newForwardersMap[nameservers[i]] = struct{}{}
|
||||
newForwarders[i] = fmt.Sprintf("%q", nameservers[i])
|
||||
}
|
||||
oldForwarders := old[:0]
|
||||
for _, fwd := range old {
|
||||
if _, ok := newForwardersMap[fwd]; !ok {
|
||||
oldForwarders = append(oldForwarders, fwd)
|
||||
// parseDNSServers splits a DNS server string that may be comma- or space-separated,
|
||||
// and trims any extraneous whitespace or null characters.
|
||||
func parseDNSServers(val string) []string {
|
||||
fields := strings.FieldsFunc(val, func(r rune) bool {
|
||||
return r == ' ' || r == ','
|
||||
})
|
||||
var servers []string
|
||||
for _, f := range fields {
|
||||
trimmed := strings.TrimSpace(f)
|
||||
if len(trimmed) > 0 {
|
||||
servers = append(servers, trimmed)
|
||||
}
|
||||
}
|
||||
// NOTE: It is important to add new forwarder before removing old one.
|
||||
// Testing on Windows Server 2022 shows that removing forwarder1
|
||||
// then adding forwarder2 sometimes ends up adding both of them
|
||||
// to the forwarders list.
|
||||
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ","))
|
||||
if len(oldForwarders) > 0 {
|
||||
cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ","))
|
||||
}
|
||||
if out, err := powershell(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeDnsServerForwarders removes given nameservers from DNS server forwarders list.
|
||||
func removeDnsServerForwarders(nameservers []string) error {
|
||||
for _, ns := range nameservers {
|
||||
cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns)
|
||||
if out, err := powershell(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return servers
|
||||
}
|
||||
|
||||
76
cmd/cli/os_windows_test.go
Normal file
76
cmd/cli/os_windows_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func Test_currentStaticDNS(t *testing.T) {
|
||||
iface, err := net.InterfaceByName(defaultIfaceName())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
start := time.Now()
|
||||
staticDns, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(staticDns)
|
||||
slices.Sort(staticDnsPowershell)
|
||||
if !slices.Equal(staticDns, staticDnsPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
|
||||
}
|
||||
}
|
||||
|
||||
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// powershell runs the given powershell command.
|
||||
func powershell(cmd string) ([]byte, error) {
|
||||
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
|
||||
return bytes.TrimSpace(out), err
|
||||
}
|
||||
1053
cmd/cli/prog.go
1053
cmd/cli/prog.go
File diff suppressed because it is too large
Load Diff
@@ -4,8 +4,10 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// setDependencies sets service dependencies for Darwin
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// setDependencies sets service dependencies for FreeBSD
|
||||
func setDependencies(svc *service.Config) {
|
||||
// TODO(cuonglm): remove once https://github.com/kardianos/service/issues/359 fixed.
|
||||
_ = os.MkdirAll("/usr/local/etc/rc.d", 0755)
|
||||
}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {}
|
||||
|
||||
@@ -9,12 +9,10 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo"); err == nil {
|
||||
if r, err := newLoopbackOSConfigurator(); err == nil {
|
||||
useSystemdResolved = r.Mode() == "systemd-resolved"
|
||||
}
|
||||
// Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911
|
||||
@@ -23,6 +21,7 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// setDependencies sets service dependencies for Linux
|
||||
func setDependencies(svc *service.Config) {
|
||||
svc.Dependencies = []string{
|
||||
"Wants=network-online.target",
|
||||
@@ -39,6 +38,7 @@ func setDependencies(svc *service.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
33
cmd/cli/prog_log.go
Normal file
33
cmd/cli/prog_log.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cli
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld"
|
||||
|
||||
// Debug starts a new message with debug level.
|
||||
func (p *prog) Debug() *ctrld.LogEvent {
|
||||
return p.logger.Load().Debug()
|
||||
}
|
||||
|
||||
// Warn starts a new message with warn level.
|
||||
func (p *prog) Warn() *ctrld.LogEvent {
|
||||
return p.logger.Load().Warn()
|
||||
}
|
||||
|
||||
// Info starts a new message with info level.
|
||||
func (p *prog) Info() *ctrld.LogEvent {
|
||||
return p.logger.Load().Info()
|
||||
}
|
||||
|
||||
// Fatal starts a new message with fatal level.
|
||||
func (p *prog) Fatal() *ctrld.LogEvent {
|
||||
return p.logger.Load().Fatal()
|
||||
}
|
||||
|
||||
// Error starts a new message with error level.
|
||||
func (p *prog) Error() *ctrld.LogEvent {
|
||||
return p.logger.Load().Error()
|
||||
}
|
||||
|
||||
// Notice starts a new message with notice level.
|
||||
func (p *prog) Notice() *ctrld.LogEvent {
|
||||
return p.logger.Load().Notice()
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
//go:build !linux && !freebsd && !darwin
|
||||
//go:build !linux && !freebsd && !darwin && !windows
|
||||
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
// setDependencies sets service dependencies for other platforms
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
|
||||
@@ -4,8 +4,11 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
|
||||
@@ -55,3 +58,209 @@ func Test_prog_dnsWatchdogInterval(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_shouldUpgrade(t *testing.T) {
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is empty",
|
||||
},
|
||||
{
|
||||
name: "invalid version target",
|
||||
versionTarget: "invalid-version",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is invalid",
|
||||
},
|
||||
{
|
||||
name: "same version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version equals current version",
|
||||
},
|
||||
{
|
||||
name: "older version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.1.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version is older than current version",
|
||||
},
|
||||
{
|
||||
name: "patch upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow patch version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "minor upgrade allowed",
|
||||
versionTarget: "v1.1.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow minor version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "major upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "major downgrade blocked",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v2.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version downgrade",
|
||||
},
|
||||
{
|
||||
name: "version without v prefix",
|
||||
versionTarget: "1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should handle version target without v prefix",
|
||||
},
|
||||
{
|
||||
name: "complex version upgrade allowed",
|
||||
versionTarget: "v1.5.3",
|
||||
currentVersion: makeVersion("v1.4.2"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow complex version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "complex major upgrade blocked",
|
||||
versionTarget: "v3.1.0",
|
||||
currentVersion: makeVersion("v2.5.3"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block complex major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "pre-release version upgrade allowed",
|
||||
versionTarget: "v1.0.1-beta.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow pre-release version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "pre-release major upgrade blocked",
|
||||
versionTarget: "v2.0.0-alpha.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block pre-release major version upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
|
||||
|
||||
// Call the function and capture the result
|
||||
result := shouldUpgrade(tc.versionTarget, tc.currentVersion, testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_selfUpgradeCheck(t *testing.T) {
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow upgrade and attempt to perform it",
|
||||
},
|
||||
{
|
||||
name: "upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block upgrade and not attempt to perform it",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
|
||||
|
||||
// Call the function and capture the result
|
||||
result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_performUpgrade(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid version target",
|
||||
versionTarget: "v1.0.1",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade with valid version target",
|
||||
},
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade even with empty version target",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
|
||||
// Call the function and capture the result
|
||||
result := performUpgrade(tc.versionTarget, testLogger)
|
||||
assert.Equal(t, tc.expectedResult, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
12
cmd/cli/prog_windows.go
Normal file
12
cmd/cli/prog_windows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
// setDependencies sets service dependencies for Windows
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
@@ -2,6 +2,8 @@ package cli
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
// Prometheus metrics label constants for consistent labeling across all metrics
|
||||
// These ensure standardized metric labeling for monitoring and alerting
|
||||
const (
|
||||
metricsLabelListener = "listener"
|
||||
metricsLabelClientSourceIP = "client_source_ip"
|
||||
@@ -13,17 +15,21 @@ const (
|
||||
)
|
||||
|
||||
// statsVersion represent ctrld version.
|
||||
// This metric provides version information for monitoring and debugging
|
||||
var statsVersion = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_build_info",
|
||||
Help: "Version of ctrld process.",
|
||||
}, []string{"gitref", "goversion", "version"})
|
||||
|
||||
// statsTimeStart represents start time of ctrld service.
|
||||
// This metric tracks service uptime and helps with monitoring service restarts
|
||||
var statsTimeStart = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ctrld_time_seconds",
|
||||
Help: "Start time of the ctrld process since unix epoch in seconds.",
|
||||
})
|
||||
|
||||
// statsQueriesCountLabels defines the labels for query count metrics
|
||||
// These labels provide detailed breakdown of DNS query statistics
|
||||
var statsQueriesCountLabels = []string{
|
||||
metricsLabelListener,
|
||||
metricsLabelClientSourceIP,
|
||||
@@ -35,6 +41,7 @@ var statsQueriesCountLabels = []string{
|
||||
}
|
||||
|
||||
// statsQueriesCount counts total number of queries.
|
||||
// This provides comprehensive DNS query statistics for monitoring and alerting
|
||||
var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_queries_count",
|
||||
Help: "Total number of queries.",
|
||||
@@ -44,12 +51,14 @@ var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
//
|
||||
// The labels "client_source_ip", "client_mac", "client_hostname" are unbounded,
|
||||
// thus this stat is highly inefficient if there are many devices.
|
||||
// This metric should be used carefully in high-client environments
|
||||
var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_client_queries_count",
|
||||
Help: "Total number queries of a client.",
|
||||
}, []string{metricsLabelClientSourceIP, metricsLabelClientMac, metricsLabelClientHostname})
|
||||
|
||||
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
|
||||
// This provides conditional metric collection to avoid performance impact when metrics are disabled
|
||||
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
|
||||
if p.metricsQueryStats.Load() {
|
||||
c.WithLabelValues(lvs...).Inc()
|
||||
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// notifyReloadSigCh sends reload signal to the channel
|
||||
func notifyReloadSigCh(ch chan os.Signal) {
|
||||
signal.Notify(ch, syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
// sendReloadSignal sends a reload signal to the current process
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// notifyReloadSigCh is a no-op on Windows platforms
|
||||
func notifyReloadSigCh(ch chan os.Signal) {}
|
||||
|
||||
// sendReloadSignal sends a reload signal to the program
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
select {
|
||||
case p.reloadCh <- struct{}{}:
|
||||
|
||||
@@ -4,31 +4,44 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
|
||||
// Returns nil if no nameservers are found.
|
||||
// This function parses the system DNS configuration to understand current nameserver settings
|
||||
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
|
||||
return resolvconffile.NameserversFromFile(path)
|
||||
}
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
// This ensures that DNS settings are not overridden by other applications or system processes
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
resolvConfPath := "/etc/resolv.conf"
|
||||
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
|
||||
// This handles systems where resolv.conf is a symlink to another location
|
||||
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
|
||||
resolvConfPath = rp
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath)
|
||||
p.Debug().Msgf("Start watching %s file", resolvConfPath)
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||
p.Warn().Err(err).Msg("Could not create watcher for /etc/resolv.conf")
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
// We watch /etc instead of /etc/resolv.conf directly,
|
||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||
// This is necessary because some systems don't properly notify on file changes
|
||||
watchDir := filepath.Dir(resolvConfPath)
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
|
||||
p.Warn().Err(err).Msgf("Could not add %s to watcher list", watchDir)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -37,9 +50,12 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||
p.Debug().Msgf("Stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if p.recoveryRunning.Load() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
@@ -47,24 +63,91 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
// This allows us to detect when the resolv.conf has been modified
|
||||
expectedNS := make([]string, len(ns))
|
||||
for i, addr := range ns {
|
||||
expectedNS[i] = addr.String()
|
||||
}
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
|
||||
var foundNS []string
|
||||
var err error
|
||||
|
||||
maxRetries := 1
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
p.Error().Err(err).Msg("Failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
// If we found nameservers, break out of retry loop
|
||||
// This handles cases where the file is being written but not yet complete
|
||||
if len(foundNS) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
// This handles temporary file states during updates
|
||||
if retry < maxRetries-1 {
|
||||
p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
p.Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
|
||||
// If we found nameservers, check if they match what we expect
|
||||
if len(foundNS) > 0 {
|
||||
// Check if the nameservers match exactly what we expect
|
||||
matches := len(foundNS) == len(expectedNS)
|
||||
if matches {
|
||||
for i := range foundNS {
|
||||
if foundNS[i] != expectedNS[i] {
|
||||
matches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
Msg("checking nameservers")
|
||||
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf")
|
||||
p.Error().Err(err).Msg("Could not get event for /etc/resolv.conf")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
servers := make([]string, len(ns))
|
||||
for i := range ns {
|
||||
servers[i] = ns[i].String()
|
||||
|
||||
@@ -6,14 +6,16 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo") // interface name does not matter.
|
||||
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
|
||||
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -22,12 +24,17 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
oc.SearchDomains = sds
|
||||
} else {
|
||||
p.Debug().Err(err).Msg("Failed to get search domains list when reverting resolv.conf file")
|
||||
}
|
||||
return r.SetDNS(oc)
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, "lo") // interface name does not matter.
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -38,3 +45,8 @@ func shouldWatchResolvconf() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface.
|
||||
func newLoopbackOSConfigurator() (dns.OSConfigurator, error) {
|
||||
return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo")
|
||||
}
|
||||
|
||||
51
cmd/cli/resolvconf_test.go
Normal file
51
cmd/cli/resolvconf_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
func oldParseResolvConfNameservers(path string) ([]string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the file for "nameserver" lines
|
||||
var currentNS []string
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "nameserver") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) >= 2 {
|
||||
currentNS = append(currentNS, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentNS, nil
|
||||
}
|
||||
|
||||
// Test_prog_parseResolvConfNameservers tests the parsing of nameservers from resolv.conf content.
|
||||
// Note: The previous implementation was removed to reduce code duplication and consolidate
|
||||
// the resolv.conf handling logic into a single unified approach. All resolv.conf parsing
|
||||
// is now handled by the resolvconffile package, which provides a consistent interface
|
||||
// for both reading and modifying resolv.conf files across different platforms.
|
||||
func Test_prog_parseResolvConfNameservers(t *testing.T) {
|
||||
oldNss, _ := oldParseResolvConfNameservers(resolvconffile.Path)
|
||||
p := &prog{}
|
||||
nss, _ := p.parseResolvConfNameservers(resolvconffile.Path)
|
||||
slices.Sort(oldNss)
|
||||
slices.Sort(nss)
|
||||
if !slices.Equal(oldNss, nss) {
|
||||
t.Errorf("result mismatched, old: %v, new: %v", oldNss, nss)
|
||||
}
|
||||
t.Logf("result: %v", nss)
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(_ *net.Interface, _ []netip.Addr) error {
|
||||
func (p *prog) setResolvConf(_ *net.Interface, _ []netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
14
cmd/cli/search_domains_unix.go
Normal file
14
cmd/cli/search_domains_unix.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
return resolvconffile.SearchDomains()
|
||||
}
|
||||
43
cmd/cli/search_domains_windows.go
Normal file
43
cmd/cli/search_domains_windows.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
flags := winipcfg.GAAFlagIncludeGateways |
|
||||
winipcfg.GAAFlagIncludePrefix
|
||||
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err)
|
||||
}
|
||||
|
||||
var sds []dnsname.FQDN
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
continue
|
||||
}
|
||||
|
||||
for a := aa.FirstDNSSuffix; a != nil; a = a.Next {
|
||||
d, err := dnsname.ToFQDN(a.String())
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Failed to parse domain: %s", a.String())
|
||||
continue
|
||||
}
|
||||
sds = append(sds, d)
|
||||
}
|
||||
}
|
||||
return sds, nil
|
||||
}
|
||||
@@ -4,4 +4,5 @@ package cli
|
||||
|
||||
var supportedSelfDelete = true
|
||||
|
||||
// selfDeleteExe performs self-deletion on non-Windows platforms
|
||||
func selfDeleteExe() error { return nil }
|
||||
|
||||
@@ -33,6 +33,7 @@ type FILE_DISPOSITION_INFO struct {
|
||||
DeleteFile bool
|
||||
}
|
||||
|
||||
// dsOpenHandle opens a handle to the specified file with DELETE access
|
||||
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
|
||||
handle, err := windows.CreateFile(
|
||||
pwPath,
|
||||
@@ -51,6 +52,7 @@ func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
// dsRenameHandle renames a file handle to a stream name
|
||||
func dsRenameHandle(hHandle windows.Handle) error {
|
||||
var fRename FILE_RENAME_INFO
|
||||
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
|
||||
@@ -82,6 +84,7 @@ func dsRenameHandle(hHandle windows.Handle) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// dsDepositeHandle marks a file handle for deletion
|
||||
func dsDepositeHandle(hHandle windows.Handle) error {
|
||||
var fDelete FILE_DISPOSITION_INFO
|
||||
fDelete.DeleteFile = true
|
||||
@@ -100,6 +103,7 @@ func dsDepositeHandle(hHandle windows.Handle) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// selfDeleteExe performs self-deletion on Windows platforms
|
||||
func selfDeleteExe() error {
|
||||
var wcPath [windows.MAX_PATH + 1]uint16
|
||||
var hCurrent windows.Handle
|
||||
|
||||
@@ -5,12 +5,13 @@ package cli
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||
// selfUninstall performs self-uninstallation on non-Unix platforms
|
||||
func selfUninstall(p *prog, logger *ctrld.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, false) {
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,37 +9,39 @@ import (
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||
// selfUninstall performs self-uninstallation on Unix platforms
|
||||
func selfUninstall(p *prog, logger *ctrld.Logger) {
|
||||
if runtime.GOOS == "linux" {
|
||||
selfUninstallLinux(p, logger)
|
||||
}
|
||||
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("could not determine executable")
|
||||
logger.Fatal().Err(err).Msg("Could not determine executable")
|
||||
}
|
||||
args := []string{"uninstall"}
|
||||
if !deactivationPinNotSet() {
|
||||
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin))
|
||||
if deactivationPinSet() {
|
||||
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin.Load()))
|
||||
}
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
if err := cmd.Start(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("could not start self uninstall command")
|
||||
logger.Fatal().Err(err).Msg("Could not start self uninstall command")
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
|
||||
_ = cmd.Wait()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func selfUninstallLinux(p *prog, logger zerolog.Logger) {
|
||||
// selfUninstallLinux performs self-uninstallation on Linux platforms
|
||||
func selfUninstallLinux(p *prog, logger *ctrld.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, true) {
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
12
cmd/cli/self_upgrade_others.go
Normal file
12
cmd/cli/self_upgrade_others.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running a detached child command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{Setsid: true}
|
||||
}
|
||||
18
cmd/cli/self_upgrade_windows.go
Normal file
18
cmd/cli/self_upgrade_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags?redirectedfrom=MSDN
|
||||
|
||||
// SYSCALL_CREATE_NO_WINDOW set flag to run process without a console window.
|
||||
const SYSCALL_CREATE_NO_WINDOW = 0x08000000
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running self-upgrade command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | SYSCALL_CREATE_NO_WINDOW,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
@@ -1,24 +1,31 @@
|
||||
package cli
|
||||
|
||||
// semaphore provides a simple synchronization mechanism
|
||||
type semaphore interface {
|
||||
acquire()
|
||||
release()
|
||||
}
|
||||
|
||||
// noopSemaphore is a no-operation implementation of semaphore
|
||||
type noopSemaphore struct{}
|
||||
|
||||
// acquire performs a no-operation for the noop semaphore
|
||||
func (n noopSemaphore) acquire() {}
|
||||
|
||||
// release performs a no-operation for the noop semaphore
|
||||
func (n noopSemaphore) release() {}
|
||||
|
||||
// chanSemaphore is a channel-based implementation of semaphore
|
||||
type chanSemaphore struct {
|
||||
ready chan struct{}
|
||||
}
|
||||
|
||||
// acquire blocks until a slot is available in the semaphore
|
||||
func (c *chanSemaphore) acquire() {
|
||||
c.ready <- struct{}{}
|
||||
}
|
||||
|
||||
// release signals that a slot has been freed in the semaphore
|
||||
func (c *chanSemaphore) release() {
|
||||
<-c.ready
|
||||
}
|
||||
|
||||
@@ -4,12 +4,13 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/unit"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
// newService wraps service.New call to return service.Service
|
||||
@@ -20,10 +21,6 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case router.IsOldOpenwrt(), router.IsNetGearOrbi():
|
||||
return &procd{sysV: &sysV{s}, svcConfig: c}, nil
|
||||
case router.IsGLiNet():
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "unix-systemv":
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "linux-systemd":
|
||||
@@ -38,7 +35,7 @@ func newService(i service.Interface, c *service.Config) (service.Service, error)
|
||||
// sysV wraps a service.Service, and provide start/stop/status command
|
||||
// base on "/etc/init.d/<service_name>".
|
||||
//
|
||||
// Use this on system where "service" command is not available, like GL.iNET router.
|
||||
// Use this on system where "service" command is not available.
|
||||
type sysV struct {
|
||||
service.Service
|
||||
}
|
||||
@@ -85,37 +82,6 @@ func (s *sysV) Status() (service.Status, error) {
|
||||
return unixSystemVServiceStatus()
|
||||
}
|
||||
|
||||
// procd wraps a service.Service, and provide start/stop command
|
||||
// base on "/etc/init.d/<service_name>", status command base on parsing "ps" command output.
|
||||
//
|
||||
// Use this on system where "/etc/init.d/<service_name> status" command is not available,
|
||||
// like old GL.iNET Opal router.
|
||||
type procd struct {
|
||||
*sysV
|
||||
svcConfig *service.Config
|
||||
}
|
||||
|
||||
func (s *procd) Status() (service.Status, error) {
|
||||
if !s.installed() {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
bin := s.svcConfig.Executable
|
||||
if bin == "" {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
bin = exe
|
||||
}
|
||||
|
||||
// Looking for something like "/sbin/ctrld run ".
|
||||
shellCmd := fmt.Sprintf("ps | grep -q %q", bin+" [r]un ")
|
||||
if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
return service.StatusRunning, nil
|
||||
}
|
||||
|
||||
// systemd wraps a service.Service, and provide status command to
|
||||
// report the status correctly.
|
||||
type systemd struct {
|
||||
@@ -130,6 +96,60 @@ func (s *systemd) Status() (service.Status, error) {
|
||||
return s.Service.Status()
|
||||
}
|
||||
|
||||
func (s *systemd) Start() error {
|
||||
const systemdUnitFile = "/etc/systemd/system/ctrld.service"
|
||||
f, err := os.Open(systemdUnitFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if opts, change := ensureSystemdKillMode(f); change {
|
||||
mode := os.FileMode(0644)
|
||||
buf, err := io.ReadAll(unit.Serialize(opts))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(systemdUnitFile, buf, mode); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("systemctl daemon-reload failed: %w\n%s", err, string(out))
|
||||
}
|
||||
mainLog.Load().Debug().Msg("Set KillMode=process successfully")
|
||||
}
|
||||
return s.Service.Start()
|
||||
}
|
||||
|
||||
// ensureSystemdKillMode ensure systemd unit file is configured with KillMode=process.
|
||||
// This is necessary for running self-upgrade flow.
|
||||
func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) {
|
||||
opts, err := unit.DeserializeOptions(r)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to deserialize options")
|
||||
return
|
||||
}
|
||||
change = true
|
||||
needKillModeOpt := true
|
||||
killModeOpt := unit.NewUnitOption("Service", "KillMode", "process")
|
||||
for _, opt := range opts {
|
||||
if opt.Match(killModeOpt) {
|
||||
needKillModeOpt = false
|
||||
change = false
|
||||
break
|
||||
}
|
||||
if opt.Section == killModeOpt.Section && opt.Name == killModeOpt.Name {
|
||||
opt.Value = killModeOpt.Value
|
||||
needKillModeOpt = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needKillModeOpt {
|
||||
opts = append(opts, killModeOpt)
|
||||
}
|
||||
return opts, change
|
||||
}
|
||||
|
||||
// newLaunchd creates a new launchd service wrapper
|
||||
func newLaunchd(s service.Service) *launchd {
|
||||
return &launchd{
|
||||
Service: s,
|
||||
@@ -156,26 +176,33 @@ func (l *launchd) Status() (service.Status, error) {
|
||||
type task struct {
|
||||
f func() error
|
||||
abortOnError bool
|
||||
Name string
|
||||
}
|
||||
|
||||
// doTasks executes a list of tasks and returns success status
|
||||
func doTasks(tasks []task) bool {
|
||||
var prevErr error
|
||||
for _, task := range tasks {
|
||||
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
|
||||
if err := task.f(); err != nil {
|
||||
if task.abortOnError {
|
||||
mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error())
|
||||
mainLog.Load().Error().Msgf("Error running task %s: %v", task.Name, err)
|
||||
return false
|
||||
}
|
||||
prevErr = err
|
||||
// if this is darwin stop command, dont print debug
|
||||
// since launchctl complains on every start
|
||||
if runtime.GOOS != "darwin" || task.Name != "Stop" {
|
||||
mainLog.Load().Debug().Msgf("Error running task %s: %v", task.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkHasElevatedPrivilege checks if the process has elevated privileges and exits if not
|
||||
func checkHasElevatedPrivilege() {
|
||||
ok, err := hasElevatedPrivilege()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("could not detect user privilege: %v", err)
|
||||
mainLog.Load().Error().Msgf("Could not detect user privilege: %v", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
@@ -184,6 +211,7 @@ func checkHasElevatedPrivilege() {
|
||||
}
|
||||
}
|
||||
|
||||
// unixSystemVServiceStatus checks the status of a Unix System V service
|
||||
func unixSystemVServiceStatus() (service.Status, error) {
|
||||
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
|
||||
if err != nil {
|
||||
|
||||
@@ -6,10 +6,15 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// hasElevatedPrivilege checks if the current process has elevated privileges
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
return os.Geteuid() == 0, nil
|
||||
}
|
||||
|
||||
// openLogFile opens a log file with the specified flags
|
||||
func openLogFile(path string, flags int) (*os.File, error) {
|
||||
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions is a no-op on non-Windows platforms
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||
|
||||
28
cmd/cli/service_test.go
Normal file
28
cmd/cli/service_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_ensureSystemdKillMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
unitFile string
|
||||
wantChange bool
|
||||
}{
|
||||
{"no KillMode", "[Service]\nExecStart=/bin/sleep 1", true},
|
||||
{"not KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=mixed", true},
|
||||
{"KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
{"invalid unit file", "[Service\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, change := ensureSystemdKillMode(strings.NewReader(tc.unitFile)); tc.wantChange != change {
|
||||
t.Errorf("ensureSystemdKillMode(%q) = %v, want %v", tc.unitFile, change, tc.wantChange)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,16 @@ package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
// hasElevatedPrivilege checks if the current process has elevated privileges on Windows
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
var sid *windows.SID
|
||||
if err := windows.AllocateAndInitializeSid(
|
||||
@@ -28,6 +33,68 @@ func hasElevatedPrivilege() (bool, error) {
|
||||
return token.IsMember(sid)
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions checks if the given service
|
||||
// has the correct failure actions configured, and updates them if not.
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil // no-op on non-Windows
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// 1. Retrieve the current config
|
||||
cfg, err := s.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Update the Description
|
||||
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
|
||||
|
||||
// 3. Apply the updated config
|
||||
if err := s.UpdateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then proceed with existing actions, e.g. setting failure actions
|
||||
actions := []mgr.RecoveryAction{
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
}
|
||||
|
||||
// Set the recovery actions (3 restarts, reset period = 120).
|
||||
err = s.SetRecoveryActions(actions, 120)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that failure actions are NOT triggered on user-initiated stops.
|
||||
var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG
|
||||
failureActionsFlag.FailureActionsOnNonCrashFailures = 0
|
||||
|
||||
if err := windows.ChangeServiceConfig2(
|
||||
s.Handle,
|
||||
windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
|
||||
(*byte)(unsafe.Pointer(&failureActionsFlag)),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// openLogFile opens a log file with the specified mode on Windows
|
||||
func openLogFile(path string, mode int) (*os.File, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||
|
||||
@@ -1,39 +1,46 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 100
|
||||
maxFailureRequest = 50
|
||||
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
|
||||
checkUpstreamBackoffSleep = 2 * time.Second
|
||||
)
|
||||
|
||||
// upstreamMonitor performs monitoring upstreams health.
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
cfg *ctrld.Config
|
||||
logger atomic.Pointer[ctrld.Logger]
|
||||
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
checking map[string]bool
|
||||
down map[string]bool
|
||||
failureReq map[string]uint64
|
||||
recovered map[string]bool
|
||||
|
||||
// failureTimerActive tracks if a timer is already running for a given upstream.
|
||||
failureTimerActive map[string]bool
|
||||
}
|
||||
|
||||
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
// newUpstreamMonitor creates a new upstream monitor instance
|
||||
func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
recovered: make(map[string]bool),
|
||||
failureTimerActive: make(map[string]bool),
|
||||
}
|
||||
um.logger.Store(logger)
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
um.reset(upstream)
|
||||
@@ -42,14 +49,47 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increase failed queries count for an upstream by 1.
|
||||
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
|
||||
// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down
|
||||
// within 10 seconds if failures persist, without spawning duplicate goroutines.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
if um.recovered[upstream] {
|
||||
um.logger.Load().Debug().Msgf("Upstream %q is recovered, skipping failure count increase", upstream)
|
||||
return
|
||||
}
|
||||
|
||||
um.failureReq[upstream] += 1
|
||||
failedCount := um.failureReq[upstream]
|
||||
um.down[upstream] = failedCount >= maxFailureRequest
|
||||
|
||||
// Log the updated failure count.
|
||||
um.logger.Load().Debug().Msgf("Upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// If this is the first failure and no timer is running, start a 10-second timer.
|
||||
if failedCount == 1 && !um.failureTimerActive[upstream] {
|
||||
um.failureTimerActive[upstream] = true
|
||||
go func(upstream string) {
|
||||
time.Sleep(10 * time.Second)
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
// If no success occurred during the 10-second window (i.e. counter remains > 0)
|
||||
// and the upstream is not in a recovered state, mark it as down.
|
||||
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
|
||||
um.down[upstream] = true
|
||||
um.logger.Load().Warn().Msgf("Upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
|
||||
}
|
||||
// Reset the timer flag so that a new timer can be spawned if needed.
|
||||
um.failureTimerActive[upstream] = false
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
um.logger.Load().Warn().Msgf("Upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
@@ -63,50 +103,28 @@ func (um *upstreamMonitor) isDown(upstream string) bool {
|
||||
// reset marks an upstream as up and set failed queries counter to zero.
|
||||
func (um *upstreamMonitor) reset(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
um.failureReq[upstream] = 0
|
||||
um.down[upstream] = false
|
||||
}
|
||||
|
||||
// checkUpstream checks the given upstream status, periodically sending query to upstream
|
||||
// until successfully. An upstream status/counter will be reset once it becomes reachable.
|
||||
func (um *upstreamMonitor) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
|
||||
um.mu.Lock()
|
||||
isChecking := um.checking[upstream]
|
||||
if isChecking {
|
||||
um.mu.Unlock()
|
||||
return
|
||||
}
|
||||
um.checking[upstream] = true
|
||||
um.recovered[upstream] = true
|
||||
um.mu.Unlock()
|
||||
defer func() {
|
||||
go func() {
|
||||
// debounce the recovery to avoid incrementing failure counts already in flight
|
||||
time.Sleep(1 * time.Second)
|
||||
um.mu.Lock()
|
||||
um.checking[upstream] = false
|
||||
um.recovered[upstream] = false
|
||||
um.mu.Unlock()
|
||||
}()
|
||||
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not check upstream")
|
||||
return
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
|
||||
check := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
uc.ReBootstrap()
|
||||
_, err := resolver.Resolve(ctx, msg)
|
||||
return err
|
||||
}
|
||||
for {
|
||||
if err := check(); err == nil {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
|
||||
um.reset(upstream)
|
||||
return
|
||||
}
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
}
|
||||
}
|
||||
|
||||
// countHealthy returns the number of upstreams in the provided map that are considered healthy.
|
||||
func (um *upstreamMonitor) countHealthy(upstreams []string) int {
|
||||
var count int
|
||||
um.mu.RLock()
|
||||
for _, upstream := range upstreams {
|
||||
if !um.down[upstream] {
|
||||
count++
|
||||
}
|
||||
}
|
||||
um.mu.RUnlock()
|
||||
return count
|
||||
}
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
package main
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cli.Main()
|
||||
// make sure we exit with 0 if there are no errors
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -28,22 +28,24 @@ type AppCallback interface {
|
||||
// Start configures utility with config.toml from provided directory.
|
||||
// This function will block until Stop is called
|
||||
// Check port availability prior to calling it.
|
||||
func (c *Controller) Start(CdUID string, HomeDir string, UpstreamProto string, logLevel int, logPath string) {
|
||||
func (c *Controller) Start(CdUID string, ProvisionID string, CustomHostname string, HomeDir string, UpstreamProto string, logLevel int, logPath string) {
|
||||
if c.stopCh == nil {
|
||||
c.stopCh = make(chan struct{})
|
||||
c.Config = cli.AppConfig{
|
||||
CdUID: CdUID,
|
||||
HomeDir: HomeDir,
|
||||
UpstreamProto: UpstreamProto,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
CdUID: CdUID,
|
||||
ProvisionID: ProvisionID,
|
||||
CustomHostname: CustomHostname,
|
||||
HomeDir: HomeDir,
|
||||
UpstreamProto: UpstreamProto,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
}
|
||||
appCallback := mapCallback(c.AppCallback)
|
||||
cli.RunMobile(&c.Config, &appCallback, c.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
// As workaround to avoid circular dependency between cli and ctrld_library module
|
||||
// mapCallback maps the AppCallback interface to cli.AppCallback to avoid circular dependency
|
||||
func mapCallback(callback AppCallback) cli.AppCallback {
|
||||
return cli.AppCallback{
|
||||
HostName: func() string {
|
||||
|
||||
334
config.go
334
config.go
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ameshkov/dnsstamps"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/spf13/viper"
|
||||
@@ -51,14 +53,36 @@ const (
|
||||
FreeDnsDomain = "freedns.controld.com"
|
||||
// FreeDNSBoostrapIP is the IP address of freedns.controld.com.
|
||||
FreeDNSBoostrapIP = "76.76.2.11"
|
||||
// FreeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com.
|
||||
FreeDNSBoostrapIPv6 = "2606:1a40::11"
|
||||
// PremiumDnsDomain is the domain name of premium ControlD service.
|
||||
PremiumDnsDomain = "dns.controld.com"
|
||||
// PremiumDNSBoostrapIP is the IP address of dns.controld.com.
|
||||
PremiumDNSBoostrapIP = "76.76.2.22"
|
||||
// PremiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.com.
|
||||
PremiumDNSBoostrapIPv6 = "2606:1a40::22"
|
||||
|
||||
// freeDnsDomainDev is the domain name of free ControlD service on dev env.
|
||||
freeDnsDomainDev = "freedns.controld.dev"
|
||||
// freeDNSBoostrapIP is the IP address of freedns.controld.dev.
|
||||
freeDNSBoostrapIP = "176.125.239.11"
|
||||
// freeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com.
|
||||
freeDNSBoostrapIPv6 = "2606:1a40:f000::11"
|
||||
// premiumDnsDomainDev is the domain name of premium ControlD service on dev env.
|
||||
premiumDnsDomainDev = "dns.controld.dev"
|
||||
// premiumDNSBoostrapIP is the IP address of dns.controld.dev.
|
||||
premiumDNSBoostrapIP = "176.125.239.22"
|
||||
// premiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.dev.
|
||||
premiumDNSBoostrapIPv6 = "2606:1a40:f000::22"
|
||||
|
||||
controlDComDomain = "controld.com"
|
||||
controlDNetDomain = "controld.net"
|
||||
controlDDevDomain = "controld.dev"
|
||||
|
||||
endpointPrefixHTTPS = "https://"
|
||||
endpointPrefixQUIC = "quic://"
|
||||
endpointPrefixH3 = "h3://"
|
||||
endpointPrefixSdns = "sdns://"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -90,6 +114,10 @@ func SetConfigNameWithPath(v *viper.Viper, name, configPath string) {
|
||||
|
||||
// InitConfig initializes default config values for given *viper.Viper instance.
|
||||
func InitConfig(v *viper.Viper, name string) {
|
||||
ctx := context.Background()
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Debug(), "Config initialization started")
|
||||
|
||||
v.SetDefault("listener", map[string]*ListenerConfig{
|
||||
"0": {
|
||||
IP: "",
|
||||
@@ -128,6 +156,8 @@ func InitConfig(v *viper.Viper, name string) {
|
||||
Timeout: 3000,
|
||||
},
|
||||
})
|
||||
|
||||
Log(ctx, logger.Debug(), "Config initialization completed")
|
||||
}
|
||||
|
||||
// Config represents ctrld supported configuration.
|
||||
@@ -198,7 +228,7 @@ type ServiceConfig struct {
|
||||
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp kea-dhcp4"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
@@ -211,6 +241,8 @@ type ServiceConfig struct {
|
||||
DnsWatchdogEnabled *bool `mapstructure:"dns_watchdog_enabled" toml:"dns_watchdog_enabled,omitempty"`
|
||||
DnsWatchdogInvterval *time.Duration `mapstructure:"dns_watchdog_interval" toml:"dns_watchdog_interval,omitempty"`
|
||||
RefetchTime *int `mapstructure:"refetch_time" toml:"refetch_time,omitempty"`
|
||||
ForceRefetchWaitTime *int `mapstructure:"force_refetch_wait_time" toml:"force_refetch_wait_time,omitempty"`
|
||||
LeakOnUpstreamFailure *bool `mapstructure:"leak_on_upstream_failure" toml:"leak_on_upstream_failure,omitempty"`
|
||||
Daemon bool `mapstructure:"-" toml:"-"`
|
||||
AllocateIP bool `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
@@ -225,7 +257,7 @@ type NetworkConfig struct {
|
||||
// UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to.
|
||||
type UpstreamConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
|
||||
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy sdns ''"`
|
||||
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty"`
|
||||
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
|
||||
Domain string `mapstructure:"-" toml:"-"`
|
||||
@@ -252,6 +284,7 @@ type UpstreamConfig struct {
|
||||
http3RoundTripper6 http.RoundTripper
|
||||
certPool *x509.CertPool
|
||||
u *url.URL
|
||||
fallbackOnce sync.Once
|
||||
uid string
|
||||
}
|
||||
|
||||
@@ -282,14 +315,20 @@ func (lc *ListenerConfig) IsDirectDnsListener() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// MatchingConfig defines the configuration for rule matching behavior
|
||||
type MatchingConfig struct {
|
||||
Order []string `mapstructure:"order" toml:"order,omitempty"`
|
||||
}
|
||||
|
||||
// ListenerPolicyConfig specifies the policy rules for ctrld to filter incoming requests.
|
||||
type ListenerPolicyConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
|
||||
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Networks []Rule `mapstructure:"networks" toml:"networks,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Rules []Rule `mapstructure:"rules" toml:"rules,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
Macs []Rule `mapstructure:"macs" toml:"macs,omitempty,inline,multiline" validate:"dive,len=1"`
|
||||
FailoverRcodes []string `mapstructure:"failover_rcodes" toml:"failover_rcodes,omitempty" validate:"dive,dnsrcode"`
|
||||
FailoverRcodeNumbers []int `mapstructure:"-" toml:"-"`
|
||||
Matching *MatchingConfig `mapstructure:"-" toml:"-"`
|
||||
}
|
||||
|
||||
// Rule is a map from source to list of upstreams.
|
||||
@@ -298,11 +337,15 @@ type ListenerPolicyConfig struct {
|
||||
type Rule map[string][]string
|
||||
|
||||
// Init initialized necessary values for an UpstreamConfig.
|
||||
func (uc *UpstreamConfig) Init() {
|
||||
func (uc *UpstreamConfig) Init(ctx context.Context) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
if err := uc.initDnsStamps(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("Invalid dns stamps")
|
||||
}
|
||||
uc.initDoHScheme()
|
||||
uc.uid = upstreamUID()
|
||||
uc.uid = upstreamUID(ctx)
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
uc.Domain = u.Host
|
||||
uc.Domain = u.Hostname()
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
uc.u = u
|
||||
@@ -320,6 +363,9 @@ func (uc *UpstreamConfig) Init() {
|
||||
}
|
||||
}
|
||||
if uc.IPStack == "" {
|
||||
// Set default IP stack based on upstream type
|
||||
// Control-D upstreams use split stack for better IPv4/IPv6 handling,
|
||||
// while other upstreams use both stacks for maximum compatibility
|
||||
if uc.IsControlD() {
|
||||
uc.IPStack = IpStackSplit
|
||||
} else {
|
||||
@@ -328,6 +374,15 @@ func (uc *UpstreamConfig) Init() {
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyMsg creates and returns a new DNS message could be used for testing upstream health.
|
||||
func (uc *UpstreamConfig) VerifyMsg() *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.RecursionDesired = true
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
msg.SetEdns0(4096, false) // ensure handling of large DNS response
|
||||
return msg
|
||||
}
|
||||
|
||||
// VerifyDomain returns the domain name that could be resolved by the upstream endpoint.
|
||||
// It returns empty for non-ControlD upstream endpoint.
|
||||
func (uc *UpstreamConfig) VerifyDomain() string {
|
||||
@@ -390,24 +445,23 @@ func (uc *UpstreamConfig) SetCertPool(cp *x509.CertPool) {
|
||||
uc.certPool = cp
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
uc.setupBootstrapIP(true)
|
||||
}
|
||||
|
||||
// UID returns the unique identifier of the upstream.
|
||||
func (uc *UpstreamConfig) UID() string {
|
||||
return uc.uid
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
// SetupBootstrapIP sets up bootstrap IPs for the upstream config.
|
||||
// The setup process will block until there's usable IPs found.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP(ctx context.Context) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Debug(), "Setting up bootstrap IPs for upstream: %s", uc.Name)
|
||||
|
||||
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
|
||||
isControlD := uc.IsControlD()
|
||||
nss := initDefaultOsResolver(ctx)
|
||||
for {
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
|
||||
Log(ctx, logger.Debug(), "Looking up bootstrap IPs for domain: %s", uc.Domain)
|
||||
uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, nss)
|
||||
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
|
||||
// filtering them out here to prevent weird behavior.
|
||||
if isControlD {
|
||||
@@ -420,11 +474,20 @@ func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
}
|
||||
}
|
||||
uc.bootstrapIPs = uc.bootstrapIPs[:n]
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain)
|
||||
logger.Warn().Msgf("No record found for %q, lookup from direct ip table", uc.Domain)
|
||||
}
|
||||
}
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
logger.Warn().Msgf("No record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP)
|
||||
uc.bootstrapIPs = lookupIP(ctx, uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")})
|
||||
|
||||
}
|
||||
if len(uc.bootstrapIPs) > 0 {
|
||||
break
|
||||
}
|
||||
ProxyLogger.Load().Warn().Msg("could not resolve bootstrap IPs, retrying...")
|
||||
logger.Warn().Msg("Could not resolve bootstrap ips, retrying...")
|
||||
b.BackOff(context.Background(), errors.New("no bootstrap IPs"))
|
||||
}
|
||||
for _, ip := range uc.bootstrapIPs {
|
||||
@@ -434,11 +497,12 @@ func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip)
|
||||
}
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("bootstrap IPs: %v", uc.bootstrapIPs)
|
||||
logger.Debug().Msgf("Bootstrap ips: %v", uc.bootstrapIPs)
|
||||
Log(ctx, logger.Debug(), "Bootstrap IP setup completed for upstream: %s", uc.Name)
|
||||
}
|
||||
|
||||
// ReBootstrap re-setup the bootstrap IP and the transport.
|
||||
func (uc *UpstreamConfig) ReBootstrap() {
|
||||
func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
default:
|
||||
@@ -446,7 +510,8 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
}
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
if uc.rebootstrap.CompareAndSwap(false, true) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Debug(), "Re-bootstrapping upstream: %s", uc.Name)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
@@ -454,35 +519,35 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
|
||||
// SetupTransport initializes the network transport used to connect to upstream server.
|
||||
// For now, only DoH upstream is supported.
|
||||
func (uc *UpstreamConfig) SetupTransport() {
|
||||
func (uc *UpstreamConfig) SetupTransport(ctx context.Context) {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
uc.setupDOHTransport()
|
||||
uc.setupDOHTransport(ctx)
|
||||
case ResolverTypeDOH3:
|
||||
uc.setupDOH3Transport()
|
||||
uc.setupDOH3Transport(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) setupDOHTransport() {
|
||||
func (uc *UpstreamConfig) setupDOHTransport(ctx context.Context) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
if hasIPv6() {
|
||||
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
uc.transport4 = uc.newDOHTransport(ctx, uc.bootstrapIPs4)
|
||||
if HasIPv6(ctx) {
|
||||
uc.transport6 = uc.newDOHTransport(ctx, uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.transport6 = uc.transport4
|
||||
}
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
|
||||
uc.transport = uc.newDOHTransport(ctx, uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
func (uc *UpstreamConfig) newDOHTransport(ctx context.Context, addrs []string) *http.Transport {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.MaxIdleConnsPerHost = 100
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
@@ -502,12 +567,13 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
dialerTimeoutMs = uc.Timeout
|
||||
}
|
||||
dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond
|
||||
logger := LoggerFromCtx(ctx)
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
if uc.BootstrapIP != "" {
|
||||
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
|
||||
addr := net.JoinHostPort(uc.BootstrapIP, port)
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", addr)
|
||||
Log(ctx, logger.Debug(), "Sending doh request to: %s", addr)
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
pd := &ctrldnet.ParallelDialer{}
|
||||
@@ -517,11 +583,11 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs)
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs, logger.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending doh request to: %s", conn.RemoteAddr())
|
||||
Log(ctx, logger.Debug(), "Sending doh request to: %s", conn.RemoteAddr())
|
||||
return conn, nil
|
||||
}
|
||||
runtime.SetFinalizer(transport, func(transport *http.Transport) {
|
||||
@@ -531,16 +597,20 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
}
|
||||
|
||||
// Ping warms up the connection to DoH/DoH3 upstream.
|
||||
func (uc *UpstreamConfig) Ping() {
|
||||
_ = uc.ping()
|
||||
func (uc *UpstreamConfig) Ping(ctx context.Context) {
|
||||
if err := uc.ping(ctx); err != nil {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Debug().Err(err).Msgf("Upstream ping failed: %s", uc.Endpoint)
|
||||
_ = uc.FallbackToDirectIP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPing is like Ping, but return an error if any.
|
||||
func (uc *UpstreamConfig) ErrorPing() error {
|
||||
return uc.ping()
|
||||
func (uc *UpstreamConfig) ErrorPing(ctx context.Context) error {
|
||||
return uc.ping(ctx)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) ping() error {
|
||||
func (uc *UpstreamConfig) ping(ctx context.Context) error {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
default:
|
||||
@@ -569,12 +639,11 @@ func (uc *UpstreamConfig) ping() error {
|
||||
for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
|
||||
if err := ping(uc.dohTransport(typ)); err != nil {
|
||||
if err := ping(uc.dohTransport(ctx, typ)); err != nil {
|
||||
return err
|
||||
}
|
||||
case ResolverTypeDOH3:
|
||||
if err := ping(uc.doh3Transport(typ)); err != nil {
|
||||
if err := ping(uc.doh3Transport(ctx, typ)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -609,12 +678,12 @@ func (uc *UpstreamConfig) isNextDNS() bool {
|
||||
return domain == "dns.nextdns.io"
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
func (uc *UpstreamConfig) dohTransport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
@@ -630,7 +699,7 @@ func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
return uc.transport
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
func (uc *UpstreamConfig) bootstrapIPForDNSType(ctx context.Context, dnsType uint16) string {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth:
|
||||
return pick(uc.bootstrapIPs)
|
||||
@@ -643,7 +712,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
case dns.TypeA:
|
||||
return pick(uc.bootstrapIPs4)
|
||||
default:
|
||||
if hasIPv6() {
|
||||
if HasIPv6(ctx) {
|
||||
return pick(uc.bootstrapIPs6)
|
||||
}
|
||||
return pick(uc.bootstrapIPs4)
|
||||
@@ -652,7 +721,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
return pick(uc.bootstrapIPs)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
func (uc *UpstreamConfig) netForDNSType(ctx context.Context, dnsType uint16) (string, string) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth:
|
||||
return "tcp-tls", "udp"
|
||||
@@ -665,7 +734,7 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
case dns.TypeA:
|
||||
return "tcp4-tls", "udp4"
|
||||
default:
|
||||
if hasIPv6() {
|
||||
if HasIPv6(ctx) {
|
||||
return "tcp6-tls", "udp6"
|
||||
}
|
||||
return "tcp4-tls", "udp4"
|
||||
@@ -676,24 +745,117 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
|
||||
// initDoHScheme initializes the endpoint scheme for DoH/DoH3 upstream if not present.
|
||||
func (uc *UpstreamConfig) initDoHScheme() {
|
||||
if strings.HasPrefix(uc.Endpoint, endpointPrefixH3) && uc.Type == "" {
|
||||
uc.Type = ResolverTypeDOH3
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
case ResolverTypeDOH:
|
||||
case ResolverTypeDOH3:
|
||||
if after, found := strings.CutPrefix(uc.Endpoint, endpointPrefixH3); found {
|
||||
uc.Endpoint = endpointPrefixHTTPS + after
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(uc.Endpoint, "https://") {
|
||||
uc.Endpoint = "https://" + uc.Endpoint
|
||||
if !strings.HasPrefix(uc.Endpoint, endpointPrefixHTTPS) {
|
||||
uc.Endpoint = endpointPrefixHTTPS + uc.Endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// initDnsStamps initializes upstream config based on encoded DNS Stamps Endpoint.
|
||||
func (uc *UpstreamConfig) initDnsStamps() error {
|
||||
if strings.HasPrefix(uc.Endpoint, endpointPrefixSdns) && uc.Type == "" {
|
||||
uc.Type = ResolverTypeSDNS
|
||||
}
|
||||
if uc.Type != ResolverTypeSDNS {
|
||||
return nil
|
||||
}
|
||||
sdns, err := dnsstamps.NewServerStampFromString(uc.Endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ip, port, _ := net.SplitHostPort(sdns.ServerAddrStr)
|
||||
providerName, port2, _ := net.SplitHostPort(sdns.ProviderName)
|
||||
if port2 != "" {
|
||||
port = port2
|
||||
}
|
||||
if providerName == "" {
|
||||
providerName = sdns.ProviderName
|
||||
}
|
||||
switch sdns.Proto {
|
||||
case dnsstamps.StampProtoTypeDoH:
|
||||
uc.Type = ResolverTypeDOH
|
||||
host := sdns.ProviderName
|
||||
if port != "" && port != defaultPortFor(uc.Type) {
|
||||
host = net.JoinHostPort(providerName, port)
|
||||
}
|
||||
uc.Endpoint = "https://" + host + sdns.Path
|
||||
case dnsstamps.StampProtoTypeTLS:
|
||||
uc.Type = ResolverTypeDOT
|
||||
uc.Endpoint = net.JoinHostPort(providerName, port)
|
||||
case dnsstamps.StampProtoTypeDoQ:
|
||||
uc.Type = ResolverTypeDOQ
|
||||
uc.Endpoint = net.JoinHostPort(providerName, port)
|
||||
case dnsstamps.StampProtoTypePlain:
|
||||
uc.Type = ResolverTypeLegacy
|
||||
uc.Endpoint = sdns.ServerAddrStr
|
||||
default:
|
||||
return fmt.Errorf("unsupported stamp protocol %q", sdns.Proto)
|
||||
}
|
||||
uc.BootstrapIP = ip
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context returns a new context with timeout set from upstream config.
|
||||
func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if uc.Timeout > 0 {
|
||||
return context.WithTimeout(ctx, time.Millisecond*time.Duration(uc.Timeout))
|
||||
}
|
||||
return context.WithCancel(ctx)
|
||||
}
|
||||
|
||||
// FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain.
|
||||
func (uc *UpstreamConfig) FallbackToDirectIP(ctx context.Context) bool {
|
||||
if !uc.IsControlD() {
|
||||
return false
|
||||
}
|
||||
if uc.u == nil || uc.Domain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
done := false
|
||||
uc.fallbackOnce.Do(func() {
|
||||
var ip string
|
||||
switch {
|
||||
case dns.IsSubDomain(PremiumDnsDomain, uc.Domain):
|
||||
ip = PremiumDNSBoostrapIP
|
||||
case dns.IsSubDomain(FreeDnsDomain, uc.Domain):
|
||||
ip = FreeDNSBoostrapIP
|
||||
default:
|
||||
return
|
||||
}
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Warn(), "Using direct IP for %q: %s", uc.Endpoint, ip)
|
||||
uc.u.Host = ip
|
||||
done = true
|
||||
})
|
||||
return done
|
||||
}
|
||||
|
||||
// Init initialized necessary values for an ListenerConfig.
|
||||
func (lc *ListenerConfig) Init() {
|
||||
logger := LoggerFromCtx(context.Background())
|
||||
Log(context.Background(), logger.Debug(), "Initializing listener config")
|
||||
|
||||
if lc.Policy != nil {
|
||||
lc.Policy.FailoverRcodeNumbers = make([]int, len(lc.Policy.FailoverRcodes))
|
||||
for i, rcode := range lc.Policy.FailoverRcodes {
|
||||
lc.Policy.FailoverRcodeNumbers[i] = dnsrcode.FromString(rcode)
|
||||
}
|
||||
Log(context.Background(), logger.Debug(), "Listener policy initialized with %d failover rcodes", len(lc.Policy.FailoverRcodes))
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(), "Listener config initialization completed")
|
||||
}
|
||||
|
||||
// ValidateConfig validates the given config.
|
||||
@@ -738,6 +900,23 @@ func upstreamConfigStructLevelValidation(sl validator.StructLevel) {
|
||||
return
|
||||
}
|
||||
|
||||
// Empty type is ok only for endpoints starts with "h3://" and "sdns://".
|
||||
if uc.Type == "" && !strings.HasPrefix(uc.Endpoint, endpointPrefixH3) && !strings.HasPrefix(uc.Endpoint, endpointPrefixSdns) {
|
||||
sl.ReportError(uc.Endpoint, "type", "type", "oneof", "doh doh3 dot doq os legacy sdns")
|
||||
return
|
||||
}
|
||||
|
||||
// initDoHScheme/initDnsStamps may change upstreams information,
|
||||
// so restoring changed values after validation to keep original one.
|
||||
defer func(ep, typ string) {
|
||||
uc.Endpoint = ep
|
||||
uc.Type = typ
|
||||
}(uc.Endpoint, uc.Type)
|
||||
|
||||
if err := uc.initDnsStamps(); err != nil {
|
||||
sl.ReportError(uc.Endpoint, "endpoint", "Endpoint", "http_url", "")
|
||||
return
|
||||
}
|
||||
uc.initDoHScheme()
|
||||
// DoH/DoH3 requires endpoint is an HTTP url.
|
||||
if uc.Type == ResolverTypeDOH || uc.Type == ResolverTypeDOH3 {
|
||||
@@ -767,13 +946,19 @@ func defaultPortFor(typ string) string {
|
||||
// - If endpoint is an IP address -> ResolverTypeLegacy
|
||||
// - If endpoint starts with "https://" -> ResolverTypeDOH
|
||||
// - If endpoint starts with "quic://" -> ResolverTypeDOQ
|
||||
// - If endpoint starts with "h3://" -> ResolverTypeDOH3
|
||||
// - If endpoint starts with "sdns://" -> ResolverTypeSDNS
|
||||
// - For anything else -> ResolverTypeDOT
|
||||
func ResolverTypeFromEndpoint(endpoint string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(endpoint, "https://"):
|
||||
case strings.HasPrefix(endpoint, endpointPrefixHTTPS):
|
||||
return ResolverTypeDOH
|
||||
case strings.HasPrefix(endpoint, "quic://"):
|
||||
case strings.HasPrefix(endpoint, endpointPrefixQUIC):
|
||||
return ResolverTypeDOQ
|
||||
case strings.HasPrefix(endpoint, endpointPrefixH3):
|
||||
return ResolverTypeDOH3
|
||||
case strings.HasPrefix(endpoint, endpointPrefixSdns):
|
||||
return ResolverTypeSDNS
|
||||
}
|
||||
host := endpoint
|
||||
if strings.Contains(endpoint, ":") {
|
||||
@@ -790,13 +975,38 @@ func pick(s []string) string {
|
||||
}
|
||||
|
||||
// upstreamUID generates an unique identifier for an upstream.
|
||||
func upstreamUID() string {
|
||||
func upstreamUID(ctx context.Context) string {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
b := make([]byte, 4)
|
||||
for {
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
ProxyLogger.Load().Warn().Err(err).Msg("could not generate uid for upstream, retrying...")
|
||||
logger.Warn().Err(err).Msg("Could not generate uid for upstream, retrying...")
|
||||
continue
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the UpstreamConfig for logging.
|
||||
func (uc *UpstreamConfig) String() string {
|
||||
if uc == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}",
|
||||
uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack)
|
||||
}
|
||||
|
||||
// bootstrapIPsFromControlDDomain returns bootstrap IPs for ControlD domain.
|
||||
func bootstrapIPsFromControlDDomain(domain string) []string {
|
||||
switch {
|
||||
case dns.IsSubDomain(PremiumDnsDomain, domain):
|
||||
return []string{PremiumDNSBoostrapIP, PremiumDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(FreeDnsDomain, domain):
|
||||
return []string{FreeDNSBoostrapIP, FreeDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(premiumDnsDomainDev, domain):
|
||||
return []string{premiumDNSBoostrapIP, premiumDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(freeDnsDomainDev, domain):
|
||||
return []string{freeDNSBoostrapIP, freeDNSBoostrapIPv6}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
@@ -8,24 +9,49 @@ import (
|
||||
)
|
||||
|
||||
func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
|
||||
uc := &UpstreamConfig{
|
||||
Name: "test",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p2",
|
||||
Timeout: 5000,
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
}{
|
||||
{
|
||||
name: "doh/doh3",
|
||||
uc: &UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p2",
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "doq/dot",
|
||||
uc: &UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: ResolverTypeDOT,
|
||||
Endpoint: "p2.freedns.controld.com",
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
}
|
||||
uc.Init()
|
||||
uc.setupBootstrapIP(false)
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
t.Log(nameservers())
|
||||
t.Fatal("could not bootstrap ip without bootstrap DNS")
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Enable parallel tests once https://github.com/microsoft/wmi/issues/165 fixed.
|
||||
// t.Parallel()
|
||||
tc.uc.Init(context.Background())
|
||||
tc.uc.SetupBootstrapIP(context.Background())
|
||||
if len(tc.uc.bootstrapIPs) == 0 {
|
||||
t.Log(defaultNameservers(context.Background()))
|
||||
t.Fatalf("could not bootstrap ip: %s", tc.uc.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Log(uc)
|
||||
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_Init(t *testing.T) {
|
||||
u1, _ := url.Parse("https://example.com")
|
||||
u2, _ := url.Parse("https://example.com?k=v")
|
||||
u3, _ := url.Parse("https://freedns.controld.com/p1")
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
@@ -178,13 +204,159 @@ func TestUpstreamConfig_Init(t *testing.T) {
|
||||
u: u2,
|
||||
},
|
||||
},
|
||||
{
|
||||
"h3",
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: "doh3",
|
||||
Endpoint: "h3://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: "doh3",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"h3 without type",
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Endpoint: "h3://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: "doh3",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> doh",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://AgMAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29tAy9wMQ",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "doh",
|
||||
Endpoint: "https://freedns.controld.com/p1",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u3,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> dot",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://AwcAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29t",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com:843",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> doq",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://BAcAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29t",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "doq",
|
||||
Endpoint: "freedns.controld.com:784",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns -> legacy",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "sdns",
|
||||
Endpoint: "sdns://AAcAAAAAAAAACjc2Ljc2LjIuMTE",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "legacy",
|
||||
Endpoint: "76.76.2.11:53",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "76.76.2.11",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"sdns without type",
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Endpoint: "sdns://AAcAAAAAAAAACjc2Ljc2LjIuMTE",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "sdns",
|
||||
Type: "legacy",
|
||||
Endpoint: "76.76.2.11:53",
|
||||
BootstrapIP: "76.76.2.11",
|
||||
Domain: "76.76.2.11",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.Init(context.Background())
|
||||
tc.uc.uid = "" // we don't care about the uid.
|
||||
assert.Equal(t, tc.expected, tc.uc)
|
||||
})
|
||||
@@ -326,7 +498,7 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.Init(context.Background())
|
||||
if got := tc.uc.IsDiscoverable(); got != tc.discoverable {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.discoverable, got)
|
||||
}
|
||||
|
||||
@@ -14,34 +14,35 @@ import (
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
func (uc *UpstreamConfig) setupDOH3Transport(ctx context.Context) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
if hasIPv6() {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(ctx, uc.bootstrapIPs4)
|
||||
if HasIPv6(ctx) {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(ctx, uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
}
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(ctx, uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
rt := &http3.RoundTripper{}
|
||||
func (uc *UpstreamConfig) newDOH3Transport(ctx context.Context, addrs []string) http.RoundTripper {
|
||||
rt := &http3.Transport{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
if uc.BootstrapIP != "" {
|
||||
addr = net.JoinHostPort(uc.BootstrapIP, port)
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", addr)
|
||||
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", addr)
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -61,21 +62,21 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
Log(ctx, logger.Debug(), "Sending doh3 request to: %s", conn.RemoteAddr())
|
||||
return conn, err
|
||||
}
|
||||
runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) {
|
||||
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
|
||||
rt.CloseIdleConnections()
|
||||
})
|
||||
return rt
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
|
||||
func (uc *UpstreamConfig) doh3Transport(ctx context.Context, dnsType uint16) http.RoundTripper {
|
||||
uc.transportOnce.Do(func() {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
})
|
||||
if uc.rebootstrap.CompareAndSwap(true, false) {
|
||||
uc.SetupTransport()
|
||||
uc.SetupTransport(ctx)
|
||||
}
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
@@ -96,14 +97,14 @@ func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
|
||||
// - quic dialer is different with net.Dialer
|
||||
// - simplification for quic free version
|
||||
type parallelDialerResult struct {
|
||||
conn quic.EarlyConnection
|
||||
conn *quic.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
type quicParallelDialer struct{}
|
||||
|
||||
// Dial performs parallel dialing to the given address list.
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
|
||||
@@ -107,7 +107,11 @@ func TestConfigValidation(t *testing.T) {
|
||||
{"invalid doh/doh3 endpoint", configWithInvalidDoHEndpoint(t), true},
|
||||
{"invalid client id pref", configWithInvalidClientIDPref(t), true},
|
||||
{"doh endpoint without scheme", dohUpstreamEndpointWithoutScheme(t), false},
|
||||
{"doh endpoint without type", dohUpstreamEndpointWithoutType(t), true},
|
||||
{"doh3 endpoint without type", doh3UpstreamEndpointWithoutType(t), false},
|
||||
{"sdns endpoint without type", sdnsUpstreamEndpointWithoutType(t), false},
|
||||
{"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true},
|
||||
{"kea dhcp4 format", configWithDhcp4KeaFormat(t), false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -127,6 +131,21 @@ func TestConfigValidation(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidationDoNotChangeEndpoint(t *testing.T) {
|
||||
cfg := configWithInvalidDoHEndpoint(t)
|
||||
endpointMap := map[string]struct{}{}
|
||||
for _, uc := range cfg.Upstream {
|
||||
endpointMap[uc.Endpoint] = struct{}{}
|
||||
}
|
||||
validate := validator.New()
|
||||
_ = ctrld.ValidateConfig(validate, cfg)
|
||||
for _, uc := range cfg.Upstream {
|
||||
if _, ok := endpointMap[uc.Endpoint]; !ok {
|
||||
t.Fatalf("expected endpoint '%s' to exist", uc.Endpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigDiscoverOverride(t *testing.T) {
|
||||
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
ctrld.InitConfig(v, "test_config_discover_override")
|
||||
@@ -179,6 +198,27 @@ func dohUpstreamEndpointWithoutScheme(t *testing.T) *ctrld.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func dohUpstreamEndpointWithoutType(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "https://freedns.controld.com/p1"
|
||||
cfg.Upstream["0"].Type = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func doh3UpstreamEndpointWithoutType(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "h3://freedns.controld.com/p1"
|
||||
cfg.Upstream["0"].Type = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func sdnsUpstreamEndpointWithoutType(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "sdns://AgMAAAAAAAAACjc2Ljc2LjIuMTEAFGZyZWVkbnMuY29udHJvbGQuY29tAy9wMQ"
|
||||
cfg.Upstream["0"].Type = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func invalidUpstreamTimeout(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Timeout = -1
|
||||
@@ -268,6 +308,12 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithDhcp4KeaFormat(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Service.DHCPLeaseFileFormat = "kea-dhcp4"
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "/1.1.1.1"
|
||||
|
||||
10
desktop_darwin.go
Normal file
10
desktop_darwin.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return true }
|
||||
12
desktop_others.go
Normal file
12
desktop_others.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return false }
|
||||
22
desktop_windows.go
Normal file
22
desktop_windows.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package ctrld
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
|
||||
// isWindowsWorkStation reports whether ctrld was run on a Windows workstation machine.
|
||||
func isWindowsWorkStation() bool {
|
||||
// From https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexa
|
||||
const VER_NT_WORKSTATION = 0x0000001
|
||||
osvi := windows.RtlGetVersion()
|
||||
return osvi.ProductType == VER_NT_WORKSTATION
|
||||
}
|
||||
30
dns.go
Normal file
30
dns.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// SetCacheReply extracts and stores the necessary data from the message for a cached answer.
|
||||
func SetCacheReply(answer, msg *dns.Msg, code int) {
|
||||
answer.SetRcode(msg, code)
|
||||
cCookie := getEdns0Cookie(msg.IsEdns0())
|
||||
sCookie := getEdns0Cookie(answer.IsEdns0())
|
||||
if cCookie != nil && sCookie != nil {
|
||||
// Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc7873#section-4
|
||||
sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:]
|
||||
}
|
||||
}
|
||||
|
||||
// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present.
|
||||
func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE {
|
||||
if opt == nil {
|
||||
return nil
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if e, ok := o.(*dns.EDNS0_COOKIE); ok {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -18,10 +18,6 @@ The config file allows for advanced configuration of the `ctrld` utility to cove
|
||||
|
||||
- `/etc/controld` on *nix.
|
||||
- User's home directory on Windows.
|
||||
- Same directory with `ctrld` binary on these routers:
|
||||
- `ddwrt`
|
||||
- `merlin`
|
||||
- `freshtomato`
|
||||
- Current directory.
|
||||
|
||||
The user can choose to override default value using command line `--config` or `-c`:
|
||||
@@ -166,7 +162,6 @@ before serving the query.
|
||||
|
||||
### max_concurrent_requests
|
||||
The number of concurrent requests that will be handled, must be a non-negative integer.
|
||||
Tweaking this value depends on the capacity of your system.
|
||||
|
||||
- Type: number
|
||||
- Required: no
|
||||
@@ -179,6 +174,8 @@ Perform LAN client discovery using mDNS. This will spawn a listener on port 5353
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_arp
|
||||
Perform LAN client discovery using ARP.
|
||||
|
||||
@@ -186,6 +183,8 @@ Perform LAN client discovery using ARP.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_dhcp
|
||||
Perform LAN client discovery using DHCP leases files. Common file locations are auto-discovered.
|
||||
|
||||
@@ -193,6 +192,8 @@ Perform LAN client discovery using DHCP leases files. Common file locations are
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_ptr
|
||||
Perform LAN client discovery using PTR queries.
|
||||
|
||||
@@ -200,6 +201,8 @@ Perform LAN client discovery using PTR queries.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_hosts
|
||||
Perform LAN client discovery using hosts file.
|
||||
|
||||
@@ -207,6 +210,8 @@ Perform LAN client discovery using hosts file.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_refresh_interval
|
||||
Time in seconds between each discovery refresh loop to update new client information data.
|
||||
The default value is 120 seconds, lower this value to make the discovery process run more aggressively.
|
||||
@@ -253,9 +258,7 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus
|
||||
- Default: ""
|
||||
|
||||
### dns_watchdog_enabled
|
||||
Checking DNS changes to network interfaces and reverting to ctrld's own settings.
|
||||
|
||||
The DNS watchdog process only runs on Windows and MacOS.
|
||||
Watches all physical interfaces for DNS changes and reverts them to ctrld's settings.The DNS watchdog process only runs on Windows and MacOS.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
@@ -274,13 +277,20 @@ If the time duration is non-positive, default value will be used.
|
||||
- Default: 20s
|
||||
|
||||
### refetch_time
|
||||
Time in seconds between each iteration that reloads custom config if changed.
|
||||
Time in seconds between each iteration that reloads custom config from the API.
|
||||
|
||||
The value must be a positive number, any invalid value will be ignored and default value will be used.
|
||||
- Type: number
|
||||
- Required: no
|
||||
- Default: 3600
|
||||
|
||||
### leak_on_upstream_failure
|
||||
If a remote upstream fails to resolve a query or is unreachable, `ctrld` will forward the queries to the default DNS resolver on the network. If failures persist, `ctrld` will remove itself from all networking interfaces until connectivity is restored.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: true on Windows, MacOS and Linux.
|
||||
|
||||
## Upstream
|
||||
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
|
||||
|
||||
@@ -524,6 +534,15 @@ rules = [
|
||||
]
|
||||
```
|
||||
|
||||
If there is no explicitly defined rules, LAN queries will be handled solely by the OS resolver.
|
||||
|
||||
These following domains are considered LAN queries:
|
||||
|
||||
- Queries does not have dot `.` in domain name, like `machine1`, `example`, ... (1)
|
||||
- Queries have domain ends with: `.domain`, `.lan`, `.local`. (2)
|
||||
- All `SRV` queries of LAN hostname (1) + (2).
|
||||
- `PTR` queries with private IPs.
|
||||
|
||||
---
|
||||
|
||||
Note that the order of matching preference:
|
||||
@@ -557,6 +576,12 @@ And within each policy, the rules are processed from top to bottom.
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
---
|
||||
|
||||
Note that the domain comparisons are done in case in-sensitive manner following [RFC 1034](https://datatracker.ietf.org/doc/html/rfc1034#section-3.1)
|
||||
|
||||
---
|
||||
|
||||
### macs:
|
||||
`macs` is the list of mac rules within the policy. Mac address value is case-insensitive.
|
||||
|
||||
|
||||
42
docs/known-issues.md
Normal file
42
docs/known-issues.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Known Issues
|
||||
|
||||
This document outlines known issues with ctrld and their current status, workarounds, and recommendations.
|
||||
|
||||
## macOS (Darwin) Issues
|
||||
|
||||
### Self-Upgrade Issue on Darwin 15.5
|
||||
|
||||
**Issue**: ctrld self-upgrading functionality may not work on macOS Darwin 15.5.
|
||||
|
||||
**Status**: Under investigation
|
||||
|
||||
**Description**: Users on macOS Darwin 15.5 may experience issues when ctrld attempts to perform automatic self-upgrades. The upgrade process would be triggered, but ctrld won't be upgraded.
|
||||
|
||||
**Workarounds**:
|
||||
1. **Recommended**: Upgrade your macOS system to Darwin 15.6 or later, which has been tested and verified to work correctly with ctrld self-upgrade functionality.
|
||||
2. **Alternative**: Run `ctrld upgrade prod` directly to manually upgrade ctrld to the latest version on Darwin 15.5.
|
||||
|
||||
**Affected Versions**: ctrld v1.4.2 and later on macOS Darwin 15.5
|
||||
|
||||
**Last Updated**: 05/09/2025
|
||||
|
||||
---
|
||||
|
||||
## Contributing to Known Issues
|
||||
|
||||
If you encounter an issue not listed here, please:
|
||||
|
||||
1. Check the [GitHub Issues](https://github.com/Control-D-Inc/ctrld/issues) to see if it's already reported
|
||||
2. If not reported, create a new issue with:
|
||||
- Detailed description of the problem
|
||||
- Steps to reproduce
|
||||
- Expected vs actual behavior
|
||||
- System information (OS, version, architecture)
|
||||
- ctrld version
|
||||
|
||||
## Issue Status Legend
|
||||
|
||||
- **Under investigation**: Issue is confirmed and being analyzed
|
||||
- **Workaround available**: Temporary solution exists while permanent fix is developed
|
||||
- **Fixed**: Issue has been resolved in a specific version
|
||||
- **Won't fix**: Issue is acknowledged but will not be addressed due to technical limitations or design decisions
|
||||
46
docs/runtime-internal-logging.md
Normal file
46
docs/runtime-internal-logging.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Runtime Internal Logging
|
||||
|
||||
When no logging is configured (i.e., `log_path` is not set), ctrld automatically enables an internal logging system. This system stores logs in memory to provide troubleshooting information when problems occur.
|
||||
|
||||
## Purpose
|
||||
|
||||
The runtime internal logging system is designed primarily for **ctrld developers**, not end users. It captures detailed diagnostic information that can be useful for troubleshooting issues when they arise, especially in production environments where explicit logging may not be configured.
|
||||
|
||||
## When It's Enabled
|
||||
|
||||
Internal logging is automatically enabled when:
|
||||
|
||||
- ctrld is running in Control D mode (i.e., `--cd` flag is provided)
|
||||
- No log file is configured (i.e., `log_path` is empty or not set)
|
||||
|
||||
If a log file is explicitly configured via `log_path`, internal logging will **not** be enabled, as the configured log file serves the logging purpose.
|
||||
|
||||
## How It Works
|
||||
|
||||
The internal logging system:
|
||||
|
||||
- Stores logs in **in-memory buffers** (not written to disk)
|
||||
- Captures logs at **debug level** for normal operations and **warn level** for warnings
|
||||
- Maintains separate buffers for normal logs and warning logs
|
||||
- Automatically manages buffer size to prevent unbounded memory growth
|
||||
- Preserves initialization logs even when buffers overflow
|
||||
|
||||
## Configuration
|
||||
|
||||
**Important**: The `log_level` configuration option does **not** affect the internal logging system. Internal logging always operates at debug level for normal logs and warn level for warnings, regardless of the `log_level` setting in the configuration file.
|
||||
|
||||
The `log_level` setting only affects:
|
||||
- Console output (when running interactively)
|
||||
- File-based logging (when `log_path` is configured)
|
||||
|
||||
## Accessing Internal Logs
|
||||
|
||||
Internal logs can be accessed through the control server API endpoints. This functionality is intended for developers and support personnel who need to diagnose issues.
|
||||
|
||||
## Notes
|
||||
|
||||
- Internal logging is **not** a replacement for proper log file configuration in production environments
|
||||
- For production deployments, it is recommended to configure `log_path` to enable persistent file-based logging
|
||||
- Internal logs are stored in memory and will be lost if the process terminates unexpectedly
|
||||
- The internal logging system is automatically disabled when explicit logging is configured
|
||||
|
||||
135
docs/v2.0.0-breaking-changes.md
Normal file
135
docs/v2.0.0-breaking-changes.md
Normal file
@@ -0,0 +1,135 @@
|
||||
# ctrld v2.0.0 Breaking Changes
|
||||
|
||||
This document outlines the breaking changes introduced in ctrld v2.0.0 and provides migration guidance for affected users.
|
||||
|
||||
## Overview
|
||||
|
||||
ctrld v2.0.0 removes automatic configuration support for router and server platforms. This means ctrld will no longer perform "magic" configuration to automatically set itself up as an upstream for existing DNS software on these platforms.
|
||||
|
||||
## What's Changing
|
||||
|
||||
### Removed Platform Support
|
||||
|
||||
**Router Platforms:**
|
||||
- ctrld will no longer automatically configure itself as an upstream for dnsmasq or other DNS software
|
||||
- No automatic detection and configuration of router-specific DNS settings
|
||||
|
||||
**Server Platforms:**
|
||||
- ctrld will no longer automatically configure Windows Server DNS forwarder settings
|
||||
- No automatic integration with server DNS services
|
||||
|
||||
### What Remains Supported
|
||||
|
||||
**Desktop Platforms:**
|
||||
- Windows Desktop
|
||||
- macOS Desktop
|
||||
- Linux Desktop
|
||||
|
||||
These platforms continue to receive full automatic configuration support.
|
||||
|
||||
## Stay on v1.x.x
|
||||
|
||||
ctrld v1.x.x will continue to be supported for router and server platforms:
|
||||
- Important bug fixes (regression or security) will be cherry-picked to v1.x.x branch
|
||||
- New features may still be added (but may take longer to implement)
|
||||
- Long-term support for these platforms
|
||||
|
||||
## Migration Path for Router and Server Users
|
||||
|
||||
If you're currently using ctrld v1.x.x on router or server platforms, you need to follow these steps to migrate to v2.0.0:
|
||||
|
||||
### Step 1: Downloading ctrld v2 binary
|
||||
|
||||
To download ctrld v2.0.0, follow these steps:
|
||||
|
||||
Stop the current ctrld service:
|
||||
|
||||
```sh
|
||||
ctrld stop
|
||||
```
|
||||
|
||||
Or uninstall the current version:
|
||||
|
||||
```sh
|
||||
ctrld uninstall
|
||||
```
|
||||
|
||||
Download the appropriate binary for your platform: https://dl.controld.com/v2/linux-amd64/ctrld
|
||||
|
||||
> **Note**: Replace `amd64` with your platform architecture as needed.
|
||||
|
||||
Verify that the binary was updated correctly:
|
||||
|
||||
```sh
|
||||
ctrld --version
|
||||
```
|
||||
|
||||
Expected output:
|
||||
```
|
||||
ctrld version v2.0.0
|
||||
```
|
||||
|
||||
### Step 2: Start ctrld without self-checking
|
||||
|
||||
You have two ways to start ctrld:
|
||||
|
||||
**Option A: Use Remote Configuration (Recommended)**
|
||||
1. **Export your current configuration:**
|
||||
- Copy the contents of your current `ctrld.toml` file
|
||||
|
||||
2. **Import to Control D Dashboard:**
|
||||
- Log into your Control D dashboard
|
||||
- Use the remote configuration feature to upload your configuration
|
||||
|
||||
3. **Start ctrld with remote config:**
|
||||
```bash
|
||||
sudo ctrld service start --cd=<your_uid> --skip_self_checks
|
||||
```
|
||||
|
||||
> **Note**: You must use `ctrld service start` to prevent DNS being set automatically by ctrld.
|
||||
|
||||
**Option B: Use Local Configuration**
|
||||
```bash
|
||||
sudo ctrld service start --skip_self_checks
|
||||
```
|
||||
|
||||
### Step 3: Configure DNS Software to Use ctrld as Upstream
|
||||
|
||||
**For dnsmasq users:**
|
||||
1. Configure dnsmasq to use ctrld as upstream:
|
||||
```bash
|
||||
# Add to dnsmasq.conf
|
||||
no-resolv
|
||||
server=127.0.0.1#5354
|
||||
add-mac
|
||||
add-subnet=32,128
|
||||
# Disable cache or set max-cache-ttl=0
|
||||
# to prevent queries from caching
|
||||
cache-size=0
|
||||
# max-cache-ttl=0
|
||||
```
|
||||
2. Restart dnsmasq:
|
||||
```bash
|
||||
sudo service dnsmasq restart
|
||||
```
|
||||
|
||||
**For Windows Server users:**
|
||||
1. Configure DNS forwarder in Windows Server:
|
||||
- Open DNS Manager
|
||||
- Right-click on your server name
|
||||
- Select "Properties" → "Forwarders" tab
|
||||
- Add `<ctrld listener IP>` as a forwarder
|
||||
|
||||
## Getting Help
|
||||
|
||||
If you encounter any issues during migration or have questions about the v2.0.0 changes:
|
||||
|
||||
1. **File an issue:** [GitHub Issues](https://github.com/Control-D-Inc/ctrld/issues)
|
||||
2. **Contact support:** Email help@controld.com.
|
||||
3. **Check documentation:** Review the [configuration documentation](config.md) for detailed setup instructions
|
||||
|
||||
## Summary
|
||||
|
||||
While ctrld v2.0.0 removes automatic configuration for router and server platforms, it provides a more focused experience for desktop users while still allowing router/server users to continue using ctrld with manual configuration or by staying on the v1.x.x branch.
|
||||
|
||||
The migration path is designed to be straightforward, with multiple options to suit different use cases and technical comfort levels.
|
||||
82
doh.go
82
doh.go
@@ -2,6 +2,7 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -52,6 +53,9 @@ var EncodeArchNameMap = map[string]string{
|
||||
var DecodeArchNameMap = map[string]string{}
|
||||
|
||||
func init() {
|
||||
// Create reverse mappings for OS and architecture names
|
||||
// This is needed because the API expects encoded values, but we need to decode
|
||||
// them back to their original form for processing
|
||||
for k, v := range EncodeOsNameMap {
|
||||
DecodeOsNameMap[v] = k
|
||||
}
|
||||
@@ -84,8 +88,12 @@ type dohResolver struct {
|
||||
|
||||
// Resolve performs DNS query with given DNS message using DOH protocol.
|
||||
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Debug(), "DoH resolver query started")
|
||||
|
||||
data, err := msg.Pack()
|
||||
if err != nil {
|
||||
Log(ctx, logger.Error().Err(err), "Failed to pack DNS message")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -97,6 +105,7 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
endpoint.RawQuery = query.Encode()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
|
||||
if err != nil {
|
||||
Log(ctx, logger.Error().Err(err), "Could not create HTTP request")
|
||||
return nil, fmt.Errorf("could not create request: %w", err)
|
||||
}
|
||||
addHeader(ctx, req, r.uc)
|
||||
@@ -104,38 +113,55 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
if len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
c := http.Client{Transport: r.uc.dohTransport(dnsTyp)}
|
||||
c := http.Client{Transport: r.uc.dohTransport(ctx, dnsTyp)}
|
||||
if r.isDoH3 {
|
||||
transport := r.uc.doh3Transport(dnsTyp)
|
||||
transport := r.uc.doh3Transport(ctx, dnsTyp)
|
||||
if transport == nil {
|
||||
Log(ctx, logger.Error(), "DoH3 is not supported")
|
||||
return nil, errors.New("DoH3 is not supported")
|
||||
}
|
||||
c.Transport = transport
|
||||
}
|
||||
|
||||
Log(ctx, logger.Debug(), "Sending DoH request to: %s", endpoint.String())
|
||||
resp, err := c.Do(req)
|
||||
if err != nil && r.uc.FallbackToDirectIP(ctx) {
|
||||
retryCtx, cancel := r.uc.Context(context.WithoutCancel(ctx))
|
||||
defer cancel()
|
||||
logger := LoggerFromCtx(ctx)
|
||||
logger.Warn().Err(err).Msg("Retrying request after fallback to direct ip")
|
||||
resp, err = c.Do(req.Clone(retryCtx))
|
||||
}
|
||||
if err != nil {
|
||||
err = wrapUrlError(err)
|
||||
if r.isDoH3 {
|
||||
if closer, ok := c.Transport.(io.Closer); ok {
|
||||
closer.Close()
|
||||
}
|
||||
}
|
||||
Log(ctx, logger.Error().Err(err), "DoH request failed")
|
||||
return nil, fmt.Errorf("could not perform request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
Log(ctx, logger.Error().Err(err), "Could not read response body")
|
||||
return nil, fmt.Errorf("could not read message from response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
Log(ctx, logger.Error(), "Wrong response from DOH server, got: %s, status: %d", string(buf), resp.StatusCode)
|
||||
return nil, fmt.Errorf("wrong response from DOH server, got: %s, status: %d", string(buf), resp.StatusCode)
|
||||
}
|
||||
|
||||
answer := new(dns.Msg)
|
||||
if err := answer.Unpack(buf); err != nil {
|
||||
Log(ctx, logger.Error().Err(err), "Failed to unpack DNS answer")
|
||||
return nil, fmt.Errorf("answer.Unpack: %w", err)
|
||||
}
|
||||
|
||||
Log(ctx, logger.Debug(), "DoH resolver query successful")
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
@@ -155,7 +181,8 @@ func addHeader(ctx context.Context, req *http.Request, uc *UpstreamConfig) {
|
||||
}
|
||||
}
|
||||
if printed {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "sending request header: %v", dohHeader)
|
||||
logger := LoggerFromCtx(ctx)
|
||||
Log(ctx, logger.Debug(), "Sending request header: %v", dohHeader)
|
||||
}
|
||||
dohHeader.Set("Content-Type", headerApplicationDNS)
|
||||
dohHeader.Set("Accept", headerApplicationDNS)
|
||||
@@ -202,3 +229,52 @@ func newNextDNSHeaders(ci *ClientInfo) http.Header {
|
||||
}
|
||||
return header
|
||||
}
|
||||
|
||||
// wrapCertificateVerificationError wraps a certificate verification error with additional context about the certificate issuer.
|
||||
// It extracts information like the issuer, organization, and subject from the certificate for a more descriptive error output.
|
||||
// If no certificate-related information is available, it simply returns the original error unmodified.
|
||||
func wrapCertificateVerificationError(err error) error {
|
||||
var tlsErr *tls.CertificateVerificationError
|
||||
if errors.As(err, &tlsErr) {
|
||||
if len(tlsErr.UnverifiedCertificates) > 0 {
|
||||
cert := tlsErr.UnverifiedCertificates[0]
|
||||
// Extract a more user-friendly issuer name
|
||||
var issuer string
|
||||
var organization string
|
||||
if len(cert.Issuer.Organization) > 0 {
|
||||
organization = cert.Issuer.Organization[0]
|
||||
issuer = organization
|
||||
} else if cert.Issuer.CommonName != "" {
|
||||
issuer = cert.Issuer.CommonName
|
||||
} else {
|
||||
issuer = cert.Issuer.String()
|
||||
}
|
||||
|
||||
// Get the organization from the subject field as well
|
||||
if len(cert.Subject.Organization) > 0 {
|
||||
organization = cert.Subject.Organization[0]
|
||||
}
|
||||
|
||||
// Extract the subject information
|
||||
subjectCN := cert.Subject.CommonName
|
||||
if subjectCN == "" && len(cert.Subject.Organization) > 0 {
|
||||
subjectCN = cert.Subject.Organization[0]
|
||||
}
|
||||
return fmt.Errorf("%w: %s, %s, %s", tlsErr, subjectCN, organization, issuer)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// wrapUrlError inspects and wraps a URL error, focusing on certificate verification errors for detailed context.
|
||||
func wrapUrlError(err error) error {
|
||||
var urlErr *url.Error
|
||||
if errors.As(err, &urlErr) {
|
||||
var tlsErr *tls.CertificateVerificationError
|
||||
if errors.As(urlErr.Err, &tlsErr) {
|
||||
urlErr.Err = wrapCertificateVerificationError(tlsErr)
|
||||
return urlErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user