mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Compare commits
198 Commits
release-br
...
v1.4.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ca559e5a4 | ||
|
|
0e3f764299 | ||
|
|
e52402eb0c | ||
|
|
2133f31854 | ||
|
|
a198a5cd65 | ||
|
|
eb2b231bd2 | ||
|
|
7af29cfbc0 | ||
|
|
ce1a165348 | ||
|
|
fd48e6d795 | ||
|
|
d71d1341b6 | ||
|
|
21855df4af | ||
|
|
66e2d3a40a | ||
|
|
26257cf24a | ||
|
|
36a7423634 | ||
|
|
e616091249 | ||
|
|
0948161529 | ||
|
|
ce29b5d217 | ||
|
|
de24fa293e | ||
|
|
6663925c4d | ||
|
|
b9ece6d7b9 | ||
|
|
c4efa1ab97 | ||
|
|
7cea5305e1 | ||
|
|
a20fbf95de | ||
|
|
628c4302aa | ||
|
|
8dc34f8bf5 | ||
|
|
b4faf82f76 | ||
|
|
a983dfaee2 | ||
|
|
62f73bcaa2 | ||
|
|
00e9d2bdd3 | ||
|
|
ace3b1e66e | ||
|
|
d1ea1ba08c | ||
|
|
c06c8aa859 | ||
|
|
0c2cc00c4f | ||
|
|
8d6ea91f35 | ||
|
|
7dfb77228f | ||
|
|
24910f1fa6 | ||
|
|
433a61d2ee | ||
|
|
3937e885f0 | ||
|
|
c651003cc4 | ||
|
|
b7ccfcb8b4 | ||
|
|
a9ed70200b | ||
|
|
c6365e6b74 | ||
|
|
dacc67e50f | ||
|
|
c60cf33af3 | ||
|
|
f27cbe3525 | ||
|
|
2de1b9929a | ||
|
|
8bf654aece | ||
|
|
84376ed719 | ||
|
|
7a136b8874 | ||
|
|
58c0e4f15a | ||
|
|
e0d35d8ba2 | ||
|
|
3b2e48761e | ||
|
|
b27064008e | ||
|
|
1ad63827e1 | ||
|
|
20e61550c2 | ||
|
|
020b814402 | ||
|
|
e578867118 | ||
|
|
46a1039f21 | ||
|
|
cc9e27de5f | ||
|
|
6ab3ab9faf | ||
|
|
e68bfa795a | ||
|
|
e60a92e93e | ||
|
|
62fe14f76b | ||
|
|
a0c5062e3a | ||
|
|
49eb152d02 | ||
|
|
b05056423a | ||
|
|
c7168739c7 | ||
|
|
5b1faf1ce3 | ||
|
|
513a6f9ec7 | ||
|
|
8db6fa4232 | ||
|
|
5036de2602 | ||
|
|
332f8ccc37 | ||
|
|
a582195cec | ||
|
|
9fe36ae984 | ||
|
|
54cb455522 | ||
|
|
8bd3b9e474 | ||
|
|
eff5ff580b | ||
|
|
c45f863ed8 | ||
|
|
414d4e356d | ||
|
|
ef697eb781 | ||
|
|
0631ffe831 | ||
|
|
7444d8517a | ||
|
|
3480043e40 | ||
|
|
619b6e7516 | ||
|
|
0123ca44fb | ||
|
|
7929aafe2a | ||
|
|
dc433f8dc9 | ||
|
|
8ccaeeab60 | ||
|
|
043a28eb33 | ||
|
|
c329402f5d | ||
|
|
23e6ad6e1f | ||
|
|
e6de78c1fa | ||
|
|
a670708f93 | ||
|
|
4ebe2fb5f4 | ||
|
|
3403b2039d | ||
|
|
e30ad31e0f | ||
|
|
81e0bad739 | ||
|
|
7d07d738dc | ||
|
|
0fae584e65 | ||
|
|
9e83085f2a | ||
|
|
41a00c68ac | ||
|
|
e3b99bf339 | ||
|
|
5007a87d3a | ||
|
|
60e65a37a6 | ||
|
|
d37d0e942c | ||
|
|
98042d8dbd | ||
|
|
af4b826b68 | ||
|
|
253a57ca01 | ||
|
|
caf98b4dfe | ||
|
|
398f71fd00 | ||
|
|
e1301ade96 | ||
|
|
7a23f82192 | ||
|
|
715bcc4aa1 | ||
|
|
0c74838740 | ||
|
|
4b05b6da7b | ||
|
|
375844ff1a | ||
|
|
1d207379cb | ||
|
|
fb49cb71e3 | ||
|
|
9618efbcde | ||
|
|
bb2210b06a | ||
|
|
917052723d | ||
|
|
fef85cadeb | ||
|
|
4a05fb6b28 | ||
|
|
6644ce53f2 | ||
|
|
72f0b89fdc | ||
|
|
41a97a6609 | ||
|
|
38064d6ad5 | ||
|
|
ae6945cedf | ||
|
|
3132d1b032 | ||
|
|
2716ae29bd | ||
|
|
1c50c2b6af | ||
|
|
cf6d16b439 | ||
|
|
60686f55ff | ||
|
|
47d7ace3a7 | ||
|
|
2d3779ec27 | ||
|
|
595071b608 | ||
|
|
57ef717080 | ||
|
|
eb27d1482b | ||
|
|
f57972ead7 | ||
|
|
168eaf538b | ||
|
|
1560455ca3 | ||
|
|
028475a193 | ||
|
|
f7a6dbe39b | ||
|
|
e573a490c9 | ||
|
|
ce3281e70d | ||
|
|
0fbfd160c9 | ||
|
|
20759017e6 | ||
|
|
69e0aab73e | ||
|
|
7ed6733fb7 | ||
|
|
9718ab8579 | ||
|
|
2687a4a018 | ||
|
|
2d9c60dea1 | ||
|
|
841be069b7 | ||
|
|
7833132917 | ||
|
|
e9e63b0983 | ||
|
|
4df470b869 | ||
|
|
89600f6091 | ||
|
|
f986a575e8 | ||
|
|
9c2fe8d21f | ||
|
|
8bcbb9249e | ||
|
|
a95d50c0af | ||
|
|
5db7d3577b | ||
|
|
c53a0ca1c4 | ||
|
|
6fd3d1788a | ||
|
|
087c1975e5 | ||
|
|
3713cbecc3 | ||
|
|
6046789fa4 | ||
|
|
3ea69b180c | ||
|
|
db6e977e3a | ||
|
|
a5c776c846 | ||
|
|
5a566c028a | ||
|
|
ff43c74d8d | ||
|
|
3c7255569c | ||
|
|
4a92ec4d2d | ||
|
|
9bbccb4082 | ||
|
|
4f62314646 | ||
|
|
cb49d0d947 | ||
|
|
89f7874fc6 | ||
|
|
221917e80b | ||
|
|
37d41bd215 | ||
|
|
8a96b8bec4 | ||
|
|
02ee113b95 | ||
|
|
f71dd78915 | ||
|
|
cd5619a05b | ||
|
|
a63a30c76b | ||
|
|
f5ba8be182 | ||
|
|
a9f76322bd | ||
|
|
ed39269c80 | ||
|
|
09426dcd36 | ||
|
|
17941882a9 | ||
|
|
70ab8032a0 | ||
|
|
8360bdc50a | ||
|
|
6837176ec7 | ||
|
|
5e9b4244e7 | ||
|
|
9b6a308958 | ||
|
|
71e327653a | ||
|
|
a56711796f | ||
|
|
09495f2a7c |
6
.github/workflows/ci.yml
vendored
6
.github/workflows/ci.yml
vendored
@@ -9,7 +9,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ["windows-latest", "ubuntu-latest", "macOS-latest"]
|
||||
go: ["1.23.x"]
|
||||
go: ["1.24.x"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
@@ -19,8 +19,8 @@ jobs:
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- run: "go test -race ./..."
|
||||
- uses: dominikh/staticcheck-action@v1.2.0
|
||||
- uses: dominikh/staticcheck-action@v1.3.1
|
||||
with:
|
||||
version: "2024.1.1"
|
||||
version: "2025.1"
|
||||
install-go: false
|
||||
cache-key: ${{ matrix.go }}
|
||||
|
||||
198
README.md
198
README.md
@@ -4,12 +4,12 @@
|
||||
[](https://pkg.go.dev/github.com/Control-D-Inc/ctrld)
|
||||
[](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld)
|
||||
|
||||

|
||||

|
||||
|
||||
A highly configurable DNS forwarding proxy with support for:
|
||||
- Multiple listeners for incoming queries
|
||||
- Multiple upstreams with fallbacks
|
||||
- Multiple network policy driven DNS query steering
|
||||
- Multiple network policy driven DNS query steering (via network cidr, MAC address or FQDN)
|
||||
- Policy driven domain based "split horizon" DNS with wildcard support
|
||||
- Integrations with common router vendors and firmware
|
||||
- LAN client discovery via DHCP, mDNS, ARP, NDP, hosts file parsing
|
||||
@@ -35,13 +35,29 @@ All DNS protocols are supported, including:
|
||||
|
||||
## OS Support
|
||||
- Windows (386, amd64, arm)
|
||||
- Mac (amd64, arm64)
|
||||
- Windows Server (386, amd64)
|
||||
- MacOS (amd64, arm64)
|
||||
- Linux (386, amd64, arm, mips)
|
||||
- FreeBSD
|
||||
- Common routers (See Router Mode below)
|
||||
- FreeBSD (386, amd64, arm)
|
||||
- Common routers (See below)
|
||||
|
||||
|
||||
### Supported Routers
|
||||
You can run `ctrld` on any supported router. The list of supported routers and firmware includes:
|
||||
- Asus Merlin
|
||||
- DD-WRT
|
||||
- Firewalla
|
||||
- FreshTomato
|
||||
- GL.iNet
|
||||
- OpenWRT
|
||||
- pfSense / OPNsense
|
||||
- Synology
|
||||
- Ubiquiti (UniFi, EdgeOS)
|
||||
|
||||
`ctrld` will attempt to interface with dnsmasq (or Windows Server) whenever possible and set itself as the upstream, while running on port 5354. On FreeBSD based OSes, `ctrld` will terminate dnsmasq and unbound in order to be able to listen on port 53 directly.
|
||||
|
||||
# Install
|
||||
There are several ways to download and install `ctrld.
|
||||
There are several ways to download and install `ctrld`.
|
||||
|
||||
## Quick Install
|
||||
The simplest way to download and install `ctrld` is to use the following installer command on any UNIX-like platform:
|
||||
@@ -50,14 +66,14 @@ The simplest way to download and install `ctrld` is to use the following install
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl)"'
|
||||
```
|
||||
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative cmd:
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative PowerShell:
|
||||
```shell
|
||||
powershell -Command "(Invoke-WebRequest -Uri 'https://api.controld.com/dl' -UseBasicParsing).Content | Set-Content 'ctrld_install.bat'" && ctrld_install.bat
|
||||
(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'"
|
||||
```
|
||||
|
||||
Or you can pull and run a Docker container from [Docker Hub](https://hub.docker.com/r/controldns/ctrld)
|
||||
```
|
||||
$ docker pull controldns/ctrld
|
||||
```shell
|
||||
docker run -d --name=ctrld -p 127.0.0.1:53:53/tcp -p 127.0.0.1:53:53/udp controldns/ctrld:latest
|
||||
```
|
||||
|
||||
## Download Manually
|
||||
@@ -67,20 +83,19 @@ Alternatively, if you know what you're doing you can download pre-compiled binar
|
||||
Lastly, you can build `ctrld` from source which requires `go1.21+`:
|
||||
|
||||
```shell
|
||||
$ go build ./cmd/ctrld
|
||||
go build ./cmd/ctrld
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
$ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
$ docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
$ docker run -d --name=ctrld -p 53:53/tcp -p 53:53/udp controldns/ctrld --cd=RESOLVER_ID_GOES_HERE -vv
|
||||
```shell
|
||||
docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
```
|
||||
|
||||
|
||||
@@ -101,15 +116,16 @@ Usage:
|
||||
|
||||
Available Commands:
|
||||
run Run the DNS proxy server
|
||||
service Manage ctrld service
|
||||
start Quick start service and configure DNS on interface
|
||||
stop Quick stop service and remove DNS from interface
|
||||
restart Restart the ctrld service
|
||||
reload Reload the ctrld service
|
||||
status Show status of the ctrld service
|
||||
uninstall Stop and uninstall the ctrld service
|
||||
service Manage ctrld service
|
||||
clients Manage clients
|
||||
upgrade Upgrading ctrld to latest version
|
||||
log Manage runtime debug logs
|
||||
|
||||
Flags:
|
||||
-h, --help help for ctrld
|
||||
@@ -121,81 +137,99 @@ Use "ctrld [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
## Basic Run Mode
|
||||
To start the server with default configuration, simply run: `./ctrld run`. This will create a generic `ctrld.toml` file in the **working directory** and start the application in foreground.
|
||||
1. Start the server
|
||||
```
|
||||
$ sudo ./ctrld run
|
||||
This is the most basic way to run `ctrld`, in foreground mode. Unless you already have a config file, a default one will be generated.
|
||||
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe run
|
||||
```
|
||||
|
||||
2. Run a test query using a DNS client, for example, `dig`:
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld run
|
||||
```
|
||||
|
||||
You can then run a test query using a DNS client, for example, `dig`:
|
||||
```
|
||||
$ dig verify.controld.com @127.0.0.1 +short
|
||||
api.controld.com.
|
||||
147.185.34.1
|
||||
```
|
||||
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file and go nuts with it. To enforce a new config, restart the server.
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file that was generated. To enforce a new config, restart the server.
|
||||
|
||||
## Service Mode
|
||||
To run the application in service mode on any Windows, MacOS, Linux distibution or supported router, simply run: `./ctrld start` as system/root user. This will create a generic `ctrld.toml` file in the **user home** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and configure the listener on the default network interface. Service will start on OS boot.
|
||||
This mode will run the application as a background system service on any Windows, MacOS, Linux, FreeBSD distribution or supported router. This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot.
|
||||
|
||||
When Control D upstreams are used, `ctrld` willl [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
When Control D upstreams are used on a router type device, `ctrld` will [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
|
||||
In order to stop the service, and restore your DNS to original state, simply run `./ctrld stop`. If you wish to stop and uninstall the service permanently, run `./ctrld uninstall`.
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start
|
||||
```
|
||||
|
||||
### Supported Routers
|
||||
You can run `ctrld` on any supported router, which will function similarly to the Service Mode mentioned above. The list of supported routers and firmware includes:
|
||||
- Asus Merlin
|
||||
- DD-WRT
|
||||
- Firewalla
|
||||
- FreshTomato
|
||||
- GL.iNet
|
||||
- OpenWRT
|
||||
- pfSense / OPNsense
|
||||
- Synology
|
||||
- Ubiquiti (UniFi, EdgeOS)
|
||||
Linux or Macos
|
||||
```
|
||||
sudo ctrld start
|
||||
```
|
||||
|
||||
`ctrld` will attempt to interface with dnsmasq whenever possible and set itself as the upstream, while running on port 5354. On FreeBSD based OSes, `ctrld` will terminate dnsmasq and unbound in order to be able to listen on port 53 directly.
|
||||
If `ctrld` is not in your system path (you installed it manually), you will need to run the above commands from the directory where you installed `ctrld`.
|
||||
|
||||
In order to stop the service, and restore your DNS to original state, simply run `ctrld stop`. If you wish to stop and uninstall the service permanently, run `ctrld uninstall`.
|
||||
|
||||
### Control D Auto Configuration
|
||||
Application can be started with a specific resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `run` (foreground) or `start` (service) modes.
|
||||
## Unmanaged Service Mode
|
||||
This mode functions similarly to the "Service Mode" above except it will simply start a system service and the config defined listeners, but **will not make any changes to any network interfaces**. You can then set the `ctrld` listener(s) IP on the desired network interfaces manually.
|
||||
|
||||
The following command will start the application in foreground mode, using the free "p2" resolver, which blocks Ads & Trackers.
|
||||
### Command
|
||||
|
||||
```shell
|
||||
./ctrld run --cd p2
|
||||
```
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe service start
|
||||
```
|
||||
|
||||
Alternatively, you can use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Device.
|
||||
|
||||
```shell
|
||||
./ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above commands (in service mode only), the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- Your default network interface will be updated to use the listener started by the service
|
||||
- All OS DNS queries will be sent to the listener
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld service start
|
||||
```
|
||||
|
||||
# Configuration
|
||||
See [Configuration Docs](docs/config.md).
|
||||
`ctrld` can be configured in variety of different ways, which include: API, local config file or via cli launch args.
|
||||
|
||||
## Example
|
||||
- Start `listener.0` on 127.0.0.1:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `upstream.0` via DoH protocol
|
||||
## API Based Auto Configuration
|
||||
Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. In this mode, the application will automatically choose a non-conflicting IP and/or port and configure itself as the upstream to whatever process is running on port 53 (like dnsmasq or Windows DNS Server). This mode is used when the 1 liner installer command from the Control D onboarding guide is executed.
|
||||
|
||||
### Default Config
|
||||
The following command will use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Endpoint.
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --cd abcd1234
|
||||
```
|
||||
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above command, the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- All physical network interface will be updated to use the listener started by the service or dnsmasq upstream will be switched to `ctrld`
|
||||
- All DNS queries will be sent to the listener
|
||||
|
||||
## Manual Configuration
|
||||
`ctrld` is entirely config driven and can be configured in many different ways, please see [Configuration Docs](docs/config.md).
|
||||
|
||||
### Example
|
||||
```toml
|
||||
[listener]
|
||||
|
||||
[listener.0]
|
||||
ip = ""
|
||||
port = 0
|
||||
restricted = false
|
||||
ip = '0.0.0.0'
|
||||
port = 53
|
||||
|
||||
[network]
|
||||
|
||||
@@ -203,10 +237,6 @@ See [Configuration Docs](docs/config.md).
|
||||
cidrs = ["0.0.0.0/0"]
|
||||
name = "Network 0"
|
||||
|
||||
[service]
|
||||
log_level = "info"
|
||||
log_path = ""
|
||||
|
||||
[upstream]
|
||||
|
||||
[upstream.0]
|
||||
@@ -215,28 +245,26 @@ See [Configuration Docs](docs/config.md).
|
||||
name = "Control D - Anti-Malware"
|
||||
timeout = 5000
|
||||
type = "doh"
|
||||
|
||||
[upstream.1]
|
||||
bootstrap_ip = "76.76.2.11"
|
||||
endpoint = "p2.freedns.controld.com"
|
||||
name = "Control D - No Ads"
|
||||
timeout = 3000
|
||||
type = "doq"
|
||||
|
||||
```
|
||||
|
||||
`ctrld` will pick a working config for `listener.0` then writing the default config to disk for the first run.
|
||||
The above basic config will:
|
||||
- Start listener on 0.0.0.0:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `https://freedns.controld.com/p1` using DoH protocol
|
||||
|
||||
## Advanced Configuration
|
||||
The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file.
|
||||
## CLI Args
|
||||
If you're unable to use a config file, `ctrld` can be be supplied with basic configuration via launch arguments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
|
||||
You can also supply configuration via launch argeuments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
### Example
|
||||
```
|
||||
ctrld run --listen=127.0.0.1:53 --primary_upstream=https://freedns.controld.com/p2 --secondary_upstream=10.0.10.1:53 --domains=*.company.int,very-secure.local --log /path/to/log.log
|
||||
```
|
||||
|
||||
The above will start a foreground process and:
|
||||
- Listen on `127.0.0.1:53` for DNS queries
|
||||
- Forward all queries to `https://freedns.controld.com/p2` using DoH protocol, while...
|
||||
- Excluding `*.company.int` and `very-secure.local` matching queries, that are forwarded to `10.0.10.1:53`
|
||||
- Write a debug log to `/path/to/log.log`
|
||||
|
||||
## Contributing
|
||||
See [Contribution Guideline](./docs/contributing.md)
|
||||
|
||||
## Roadmap
|
||||
The following functionality is on the roadmap and will be available in future releases.
|
||||
- DNS intercept mode
|
||||
- Direct listener mode
|
||||
- Support for more routers (let us know which ones)
|
||||
|
||||
4
client_info_darwin.go
Normal file
4
client_info_darwin.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package ctrld
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return true }
|
||||
6
client_info_others.go
Normal file
6
client_info_others.go
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return false }
|
||||
18
client_info_windows.go
Normal file
18
client_info_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// isWindowsWorkStation reports whether ctrld was run on a Windows workstation machine.
|
||||
func isWindowsWorkStation() bool {
|
||||
// From https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexa
|
||||
const VER_NT_WORKSTATION = 0x0000001
|
||||
osvi := windows.RtlGetVersion()
|
||||
return osvi.ProductType == VER_NT_WORKSTATION
|
||||
}
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
@@ -8,3 +8,8 @@ import (
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule if present.
|
||||
func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false }
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
hh "github.com/microsoft/wmi/pkg/hardware/host"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
@@ -21,29 +26,48 @@ func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
|
||||
// Network rules are lowercase during toml config marshaling,
|
||||
// lowercase the domain here too for consistency.
|
||||
domain = strings.ToLower(domain)
|
||||
domainRuleAdded := addSplitDnsRule(cfg, domain)
|
||||
wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, "."))
|
||||
return domainRuleAdded || wildcardDomainRuleRuleAdded
|
||||
}
|
||||
|
||||
// addSplitDnsRule adds split-rule for given domain if there's no existed rule.
|
||||
// The return value indicates whether the split-rule was added or not.
|
||||
func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
|
||||
for n, lc := range cfg.Listener {
|
||||
if lc.Policy == nil {
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{}
|
||||
}
|
||||
domainRule := "*." + strings.TrimPrefix(domain, ".")
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
if _, ok := rule[domainRule]; ok {
|
||||
mainLog.Load().Debug().Msgf("domain rule already exist for listener.%s", n)
|
||||
if _, ok := rule[domain]; ok {
|
||||
mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n)
|
||||
return false
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("adding active directory domain for listener.%s", n)
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domainRule: []string{}})
|
||||
mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n)
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
|
||||
output, err := powershell(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
cs, err := hh.GetComputerSystem(whost)
|
||||
if cs != nil {
|
||||
defer cs.Close()
|
||||
}
|
||||
return string(output), nil
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pod, err := cs.GetPropertyPartOfDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if pod {
|
||||
return cs.GetPropertyDomain()
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
71
cmd/cli/ad_windows_test.go
Normal file
71
cmd/cli/ad_windows_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_getActiveDirectoryDomain(t *testing.T) {
|
||||
start := time.Now()
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
domainPowershell, err := getActiveDirectoryDomainPowershell()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if domain != domainPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func getActiveDirectoryDomainPowershell() (string, error) {
|
||||
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
|
||||
output, err := powershell(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
func Test_addSplitDnsRule(t *testing.T) {
|
||||
newCfg := func(domains ...string) *ctrld.Config {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
lc := cfg.Listener["0"]
|
||||
for _, domain := range domains {
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *ctrld.Config
|
||||
domain string
|
||||
added bool
|
||||
}{
|
||||
{"added", newCfg(), "example.com", true},
|
||||
{"TLD existed", newCfg("example.com"), "*.example.com", true},
|
||||
{"wildcard existed", newCfg("*.example.com"), "example.com", true},
|
||||
{"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false},
|
||||
{"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
added := addSplitDnsRule(tc.cfg, tc.domain)
|
||||
assert.Equal(t, tc.added, added)
|
||||
})
|
||||
}
|
||||
}
|
||||
5
cmd/cli/cgo.go
Normal file
5
cmd/cli/cgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = true
|
||||
1314
cmd/cli/cli.go
1314
cmd/cli/cli.go
File diff suppressed because it is too large
Load Diff
1397
cmd/cli/commands.go
Normal file
1397
cmd/cli/commands.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -25,6 +25,10 @@ func newControlClient(addr string) *controlClient {
|
||||
}
|
||||
|
||||
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
|
||||
// for log/send, set the timeout to 5 minutes
|
||||
if path == sendLogsPath {
|
||||
c.c.Timeout = time.Minute * 5
|
||||
}
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -25,8 +27,16 @@ const (
|
||||
deactivationPath = "/deactivation"
|
||||
cdPath = "/cd"
|
||||
ifacePath = "/iface"
|
||||
viewLogsPath = "/log/view"
|
||||
sendLogsPath = "/log/send"
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
Name string `json:"name"`
|
||||
All bool `json:"all"`
|
||||
OK bool `json:"ok"`
|
||||
}
|
||||
|
||||
type controlServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
@@ -69,33 +79,81 @@ func (s *controlServer) register(pattern string, handler http.Handler) {
|
||||
|
||||
func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
mainLog.Load().Debug().Msg("handling list clients request")
|
||||
|
||||
clients := p.ciTable.ListClients()
|
||||
mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list")
|
||||
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].IP.Less(clients[j].IP)
|
||||
})
|
||||
mainLog.Load().Debug().Msg("sorted clients by IP address")
|
||||
|
||||
if p.metricsQueryStats.Load() {
|
||||
for _, client := range clients {
|
||||
mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts")
|
||||
|
||||
for idx, client := range clients {
|
||||
mainLog.Load().Debug().
|
||||
Int("index", idx).
|
||||
Str("ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("processing client metrics")
|
||||
|
||||
client.IncludeQueryCount = true
|
||||
dm := &dto.Metric{}
|
||||
|
||||
if statsClientQueriesCount.MetricVec == nil {
|
||||
mainLog.Load().Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("skipping metrics collection: MetricVec is nil")
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
|
||||
client.IP.String(),
|
||||
client.Mac,
|
||||
client.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("could not get metrics for client: %v", client)
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("failed to get metrics for client")
|
||||
continue
|
||||
}
|
||||
if err := m.Write(dm); err == nil {
|
||||
|
||||
if err := m.Write(dm); err == nil && dm.Counter != nil {
|
||||
client.QueryCount = int64(dm.Counter.GetValue())
|
||||
mainLog.Load().Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Int64("query_count", client.QueryCount).
|
||||
Msg("successfully collected query count")
|
||||
} else if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("failed to write metric")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts")
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&clients); err != nil {
|
||||
mainLog.Load().Error().
|
||||
Err(err).
|
||||
Int("client_count", len(clients)).
|
||||
Msg("failed to encode clients response")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Int("client_count", len(clients)).
|
||||
Msg("successfully sent clients list response")
|
||||
}))
|
||||
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
select {
|
||||
@@ -170,7 +228,7 @@ func (p *prog) registerControlServerHandler() {
|
||||
}
|
||||
|
||||
// If pin code not set, allowing deactivation.
|
||||
if deactivationPinNotSet() {
|
||||
if !deactivationPinSet() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
@@ -186,6 +244,10 @@ func (p *prog) registerControlServerHandler() {
|
||||
switch req.Pin {
|
||||
case cdDeactivationPin.Load():
|
||||
code = http.StatusOK
|
||||
select {
|
||||
case p.pinCodeValidCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
case defaultDeactivationPin:
|
||||
// If the pin code was set, but users do not provide --pin, return proper code to client.
|
||||
code = http.StatusBadRequest
|
||||
@@ -201,15 +263,76 @@ func (p *prog) registerControlServerHandler() {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
res := &ifaceResponse{Name: iface}
|
||||
// p.setDNS is only called when running as a service
|
||||
if !service.Interactive() {
|
||||
<-p.csSetDnsDone
|
||||
if p.csSetDnsOk {
|
||||
w.Write([]byte(iface))
|
||||
return
|
||||
res.Name = p.runningIface
|
||||
res.All = p.requiredMultiNICsConfig
|
||||
res.OK = true
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
if err := json.NewEncoder(w).Encode(res); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
lr, err := p.logReader()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer lr.r.Close()
|
||||
if lr.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(lr.r)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if time.Since(p.internalLogSent) < logWriterSentInterval {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r, err := p.logReader()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
req := &controld.LogsRequest{
|
||||
UID: cdUID,
|
||||
Data: r.r,
|
||||
}
|
||||
mainLog.Load().Debug().Msg("sending log file to ControlD server")
|
||||
resp := logSentResponse{Size: r.size}
|
||||
if err := controld.SendLogs(req, cdDev); err != nil {
|
||||
mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err)
|
||||
resp.Error = err.Error()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("sending log file successfully")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
p.internalLogSent = time.Now()
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -41,7 +43,7 @@ const (
|
||||
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "OS resolver",
|
||||
Type: ctrld.ResolverTypeOS,
|
||||
Timeout: 2000,
|
||||
Timeout: 3000,
|
||||
}
|
||||
|
||||
var privateUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
@@ -50,6 +52,12 @@ var privateUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
var localUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "Local resolver",
|
||||
Type: ctrld.ResolverTypeLocal,
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
// proxyRequest contains data for proxying a DNS query to upstream.
|
||||
type proxyRequest struct {
|
||||
msg *dns.Msg
|
||||
@@ -106,11 +114,18 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
go p.detectLoop(m)
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
if domain == selfCheckInternalTestDomain {
|
||||
switch {
|
||||
case domain == "":
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeFormatError)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
case domain == selfCheckInternalTestDomain:
|
||||
answer := resolveInternalDomainTestQuery(ctx, domain, m)
|
||||
_ = w.WriteMsg(answer)
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := p.cacheFlushDomainsMap[domain]; ok && p.cache != nil {
|
||||
p.cache.Purge()
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "received query %q, local cache is purged", domain)
|
||||
@@ -192,8 +207,8 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
// When we spawn a listener on 127.0.0.1, also spawn listeners on the RFC1918
|
||||
// addresses of the machine. So ctrld could receive queries from LAN clients.
|
||||
// When we spawn a listener on 127.0.0.1, also spawn listeners on the RFC1918 addresses of the machine
|
||||
// if explicitly set via setting rfc1918 flag, so ctrld could receive queries from LAN clients.
|
||||
if needRFC1918Listeners(listenerConfig) {
|
||||
g.Go(func() error {
|
||||
for _, addr := range ctrld.Rfc1918Addresses() {
|
||||
@@ -411,21 +426,20 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
|
||||
leaked := false
|
||||
// If ctrld is going to leak query to OS resolver, check remote upstream in background,
|
||||
// so ctrld could be back to normal operation as long as the network is back online.
|
||||
if len(upstreamConfigs) > 0 && p.leakingQuery.Load() {
|
||||
for n, uc := range upstreamConfigs {
|
||||
go p.checkUpstream(upstreams[n], uc)
|
||||
}
|
||||
upstreamConfigs = nil
|
||||
leaked = true
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%v is down, leaking query to OS resolver", upstreams)
|
||||
}
|
||||
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{upstreamOS}
|
||||
// For OS resolver, local addresses are ignored to prevent possible looping.
|
||||
// However, on Active Directory Domain Controller, where it has local DNS server
|
||||
// running and listening on local addresses, these local addresses must be used
|
||||
// as nameservers, so queries for ADDC could be resolved as expected.
|
||||
if p.isAdDomainQuery(req.msg) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(),
|
||||
"AD domain query detected for %s in domain %s",
|
||||
req.msg.Question[0].Name, p.adDomain)
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{localUpstreamConfig}
|
||||
upstreams = []string{upstreamOSLocal}
|
||||
}
|
||||
}
|
||||
|
||||
res := &proxyResponse{}
|
||||
@@ -438,13 +452,14 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
// 4. Try remote upstream.
|
||||
isLanOrPtrQuery := false
|
||||
if req.ufr.matched {
|
||||
if leaked {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v (leaked)", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
} else {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "%s, %s, %s -> %v", req.ufr.matchedPolicy, req.ufr.matchedNetwork, req.ufr.matchedRule, upstreams)
|
||||
} else {
|
||||
switch {
|
||||
case isSrvLanLookup(req.msg):
|
||||
upstreams = []string{upstreamOS}
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "SRV record lookup, using upstreams: %v", upstreams)
|
||||
case isPrivatePtrLookup(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
if answer := p.proxyPrivatePtrLookup(ctx, req.msg); answer != nil {
|
||||
@@ -452,7 +467,8 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
res.clientInfo = true
|
||||
return res
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForPtr(upstreams, upstreamConfigs)
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "private PTR lookup, using upstreams: %v", upstreams)
|
||||
case isLanHostnameQuery(req.msg):
|
||||
isLanOrPtrQuery = true
|
||||
@@ -461,7 +477,9 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
res.clientInfo = true
|
||||
return res
|
||||
}
|
||||
upstreams, upstreamConfigs = p.upstreamsAndUpstreamConfigForLanAndPtr(upstreams, upstreamConfigs)
|
||||
upstreams = []string{upstreamOS}
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
ctx = ctrld.LanQueryCtx(ctx)
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "lan hostname lookup, using upstreams: %v", upstreams)
|
||||
default:
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "no explicit policy matched, using default routing -> %v", upstreams)
|
||||
@@ -476,7 +494,7 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
continue
|
||||
}
|
||||
answer := cachedValue.Msg.Copy()
|
||||
answer.SetRcode(req.msg, answer.Rcode)
|
||||
ctrld.SetCacheReply(answer, req.msg, answer.Rcode)
|
||||
now := time.Now()
|
||||
if cachedValue.Expire.After(now) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "hit cached response")
|
||||
@@ -488,59 +506,68 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
staleAnswer = answer
|
||||
}
|
||||
}
|
||||
resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
|
||||
resolve1 := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "sending query to %s: %s", upstream, upstreamConfig.Name)
|
||||
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to create resolver")
|
||||
return nil, err
|
||||
}
|
||||
resolveCtx, cancel := context.WithCancel(ctx)
|
||||
resolveCtx, cancel := upstreamConfig.Context(ctx)
|
||||
defer cancel()
|
||||
if upstreamConfig.Timeout > 0 {
|
||||
timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(upstreamConfig.Timeout))
|
||||
defer cancel()
|
||||
resolveCtx = timeoutCtx
|
||||
}
|
||||
return dnsResolver.Resolve(resolveCtx, msg)
|
||||
}
|
||||
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
resolve := func(upstream string, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
if upstreamConfig.UpstreamSendClientInfo() && req.ci != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "including client info with the request")
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, req.ci)
|
||||
}
|
||||
answer, err := resolve1(n, upstreamConfig, msg)
|
||||
answer, err := resolve1(upstream, upstreamConfig, msg)
|
||||
// if we have an answer, we should reset the failure count
|
||||
// we dont use reset here since we dont want to prevent failure counts from being incremented
|
||||
if answer != nil {
|
||||
p.um.mu.Lock()
|
||||
p.um.failureReq[upstream] = 0
|
||||
p.um.down[upstream] = false
|
||||
p.um.mu.Unlock()
|
||||
return answer
|
||||
}
|
||||
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
|
||||
|
||||
// increase failure count when there is no answer
|
||||
// rehardless of what kind of error we get
|
||||
p.um.increaseFailureCount(upstream)
|
||||
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Error().Err(err), "failed to resolve query")
|
||||
isNetworkErr := errNetworkError(err)
|
||||
if isNetworkErr {
|
||||
p.um.increaseFailureCount(upstreams[n])
|
||||
if p.um.isDown(upstreams[n]) {
|
||||
go p.checkUpstream(upstreams[n], upstreamConfig)
|
||||
}
|
||||
}
|
||||
// For timeout error (i.e: context deadline exceed), force re-bootstrapping.
|
||||
var e net.Error
|
||||
if errors.As(err, &e) && e.Timeout() {
|
||||
upstreamConfig.ReBootstrap()
|
||||
}
|
||||
return nil
|
||||
// For network error, turn ipv6 off if enabled.
|
||||
if ctrld.HasIPv6() && (errUrlNetworkError(err) || errNetworkError(err)) {
|
||||
ctrld.DisableIPv6()
|
||||
}
|
||||
}
|
||||
return answer
|
||||
|
||||
return nil
|
||||
}
|
||||
for n, upstreamConfig := range upstreamConfigs {
|
||||
if upstreamConfig == nil {
|
||||
continue
|
||||
}
|
||||
logger := mainLog.Load().Debug().
|
||||
Str("upstream", upstreamConfig.String()).
|
||||
Str("query", req.msg.Question[0].Name).
|
||||
Bool("is_ad_query", p.isAdDomainQuery(req.msg)).
|
||||
Bool("is_lan_query", isLanOrPtrQuery)
|
||||
|
||||
if p.isLoop(upstreamConfig) {
|
||||
mainLog.Load().Warn().Msgf("dns loop detected, upstream: %q, endpoint: %q", upstreamConfig.Name, upstreamConfig.Endpoint)
|
||||
ctrld.Log(ctx, logger, "DNS loop detected")
|
||||
continue
|
||||
}
|
||||
if p.um.isDown(upstreams[n]) {
|
||||
ctrld.Log(ctx, mainLog.Load().Warn(), "%s is down", upstreams[n])
|
||||
continue
|
||||
}
|
||||
answer := resolve(n, upstreamConfig, req.msg)
|
||||
answer := resolve(upstreams[n], upstreamConfig, req.msg)
|
||||
if answer == nil {
|
||||
if serveStaleCache && staleAnswer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "serving stale cached response")
|
||||
@@ -587,21 +614,50 @@ func (p *prog) proxy(ctx context.Context, req *proxyRequest) *proxyResponse {
|
||||
return res
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Error(), "all %v endpoints failed", upstreams)
|
||||
if cdUID != "" && p.leakOnUpstreamFailure() {
|
||||
p.leakingQueryMu.Lock()
|
||||
if !p.leakingQueryWasRun {
|
||||
p.leakingQueryWasRun = true
|
||||
go p.performLeakingQuery()
|
||||
|
||||
// if we have no healthy upstreams, trigger recovery flow
|
||||
if p.leakOnUpstreamFailure() {
|
||||
if p.um.countHealthy(upstreams) == 0 {
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel == nil {
|
||||
var reason RecoveryReason
|
||||
if upstreams[0] == upstreamOS {
|
||||
reason = RecoveryReasonOSFailure
|
||||
} else {
|
||||
reason = RecoveryReasonRegularFailure
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("No healthy upstreams, triggering recovery with reason: %v", reason)
|
||||
go p.handleRecovery(reason)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Recovery already in progress; skipping duplicate trigger from down detection")
|
||||
}
|
||||
p.recoveryCancelMu.Unlock()
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("One upstream is down but at least one is healthy; skipping recovery trigger")
|
||||
}
|
||||
|
||||
// attempt query to OS resolver while as a retry catch all
|
||||
// we dont want this to happen if leakOnUpstreamFailure is false
|
||||
if upstreams[0] != upstreamOS {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "attempting query to OS resolver as a retry catch all")
|
||||
answer := resolve(upstreamOS, osUpstreamConfig, req.msg)
|
||||
if answer != nil {
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query successful")
|
||||
res.answer = answer
|
||||
res.upstream = osUpstreamConfig.Endpoint
|
||||
return res
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Load().Debug(), "OS resolver retry query failed")
|
||||
}
|
||||
p.leakingQueryMu.Unlock()
|
||||
}
|
||||
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(req.msg, dns.RcodeServerFailure)
|
||||
res.answer = answer
|
||||
return res
|
||||
}
|
||||
|
||||
func (p *prog) upstreamsAndUpstreamConfigForLanAndPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
|
||||
func (p *prog) upstreamsAndUpstreamConfigForPtr(upstreams []string, upstreamConfigs []*ctrld.UpstreamConfig) ([]string, []*ctrld.UpstreamConfig) {
|
||||
if len(p.localUpstreams) > 0 {
|
||||
tmp := make([]string, 0, len(p.localUpstreams)+len(upstreams))
|
||||
tmp = append(tmp, p.localUpstreams...)
|
||||
@@ -620,6 +676,14 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U
|
||||
return upstreamConfigs
|
||||
}
|
||||
|
||||
func (p *prog) isAdDomainQuery(msg *dns.Msg) bool {
|
||||
if p.adDomain == "" {
|
||||
return false
|
||||
}
|
||||
cDomainName := canonicalName(msg.Question[0].Name)
|
||||
return dns.IsSubDomain(p.adDomain, cDomainName)
|
||||
}
|
||||
|
||||
// canonicalName returns canonical name from FQDN with "." trimmed.
|
||||
func canonicalName(fqdn string) string {
|
||||
q := strings.TrimSpace(fqdn)
|
||||
@@ -916,18 +980,6 @@ func (p *prog) selfUninstallCoolOfPeriod() {
|
||||
p.selfUninstallMu.Unlock()
|
||||
}
|
||||
|
||||
// performLeakingQuery performs necessary works to leak queries to OS resolver.
|
||||
func (p *prog) performLeakingQuery() {
|
||||
mainLog.Load().Warn().Msg("leaking query to OS resolver")
|
||||
// Signal dns watchers to stop, so changes made below won't be reverted.
|
||||
p.leakingQuery.Store(true)
|
||||
p.resetDNS()
|
||||
ns := ctrld.InitializeOsResolver()
|
||||
mainLog.Load().Debug().Msgf("re-initialized OS resolver with nameservers: %v", ns)
|
||||
p.dnsWg.Wait()
|
||||
p.setDNS()
|
||||
}
|
||||
|
||||
// forceFetchingAPI sends signal to force syncing API config if run in cd mode,
|
||||
// and the domain == "cdUID.verify.controld.com"
|
||||
func (p *prog) forceFetchingAPI(domain string) {
|
||||
@@ -984,8 +1036,10 @@ func (p *prog) queryFromSelf(ip string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// needRFC1918Listeners reports whether ctrld need to spawn listener for RFC 1918 addresses.
|
||||
// This is helpful for non-desktop platforms to receive queries from LAN clients.
|
||||
func needRFC1918Listeners(lc *ctrld.ListenerConfig) bool {
|
||||
return lc.IP == "127.0.0.1" && lc.Port == 53
|
||||
return rfc1918 && lc.IP == "127.0.0.1" && lc.Port == 53
|
||||
}
|
||||
|
||||
// ipFromARPA parses a FQDN arpa domain and return the IP address if valid.
|
||||
@@ -1053,10 +1107,25 @@ func isLanHostnameQuery(m *dns.Msg) bool {
|
||||
default:
|
||||
return false
|
||||
}
|
||||
name := strings.TrimSuffix(q.Name, ".")
|
||||
return isLanHostname(q.Name)
|
||||
}
|
||||
|
||||
// isSrvLanLookup reports whether DNS message is an SRV query of a LAN hostname.
|
||||
func isSrvLanLookup(m *dns.Msg) bool {
|
||||
if m == nil || len(m.Question) == 0 {
|
||||
return false
|
||||
}
|
||||
q := m.Question[0]
|
||||
return q.Qtype == dns.TypeSRV && isLanHostname(q.Name)
|
||||
}
|
||||
|
||||
// isLanHostname reports whether name is a LAN hostname.
|
||||
func isLanHostname(name string) bool {
|
||||
name = strings.TrimSuffix(name, ".")
|
||||
return !strings.Contains(name, ".") ||
|
||||
strings.HasSuffix(name, ".domain") ||
|
||||
strings.HasSuffix(name, ".lan")
|
||||
strings.HasSuffix(name, ".lan") ||
|
||||
strings.HasSuffix(name, ".local")
|
||||
}
|
||||
|
||||
// isWanClient reports whether the input is a WAN address.
|
||||
@@ -1089,3 +1158,470 @@ func resolveInternalDomainTestQuery(ctx context.Context, domain string, m *dns.M
|
||||
answer.SetReply(m)
|
||||
return answer
|
||||
}
|
||||
|
||||
// FlushDNSCache flushes the DNS cache on macOS.
|
||||
func FlushDNSCache() error {
|
||||
// if not macOS, return
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush the DNS cache via mDNSResponder.
|
||||
// This is typically needed on modern macOS systems.
|
||||
if out, err := exec.Command("killall", "-HUP", "mDNSResponder").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to flush mDNSResponder: %w, output: %s", err, string(out))
|
||||
}
|
||||
|
||||
// Optionally, flush the directory services cache.
|
||||
if out, err := exec.Command("dscacheutil", "-flushcache").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to flush dscacheutil: %w, output: %s", err, string(out))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorNetworkChanges starts monitoring for network interface changes
|
||||
func (p *prog) monitorNetworkChanges() error {
|
||||
mon, err := netmon.New(func(format string, args ...any) {
|
||||
// Always fetch the latest logger (and inject the prefix)
|
||||
mainLog.Load().Printf("netmon: "+format, args...)
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating network monitor: %w", err)
|
||||
}
|
||||
|
||||
mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
|
||||
// Get map of valid interfaces
|
||||
validIfaces := validInterfacesMap()
|
||||
|
||||
isMajorChange := mon.IsMajorChangeFrom(delta.Old, delta.New)
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Interface("old_state", delta.Old).
|
||||
Interface("new_state", delta.New).
|
||||
Bool("is_major_change", isMajorChange).
|
||||
Msg("Network change detected")
|
||||
|
||||
changed := false
|
||||
activeInterfaceExists := false
|
||||
var changeIPs []netip.Prefix
|
||||
// Check each valid interface for changes
|
||||
for ifaceName := range validIfaces {
|
||||
oldIface, oldExists := delta.Old.Interface[ifaceName]
|
||||
newIface, newExists := delta.New.Interface[ifaceName]
|
||||
if !newExists {
|
||||
continue
|
||||
}
|
||||
|
||||
oldIPs := delta.Old.InterfaceIPs[ifaceName]
|
||||
newIPs := delta.New.InterfaceIPs[ifaceName]
|
||||
|
||||
// if a valid interface did not exist in old
|
||||
// check that its up and has usable IPs
|
||||
if !oldExists {
|
||||
// The interface is new (was not present in the old state).
|
||||
usableNewIPs := filterUsableIPs(newIPs)
|
||||
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
||||
changed = true
|
||||
changeIPs = usableNewIPs
|
||||
mainLog.Load().Debug().
|
||||
Str("interface", ifaceName).
|
||||
Interface("new_ips", usableNewIPs).
|
||||
Msg("Interface newly appeared (was not present in old state)")
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter new IPs to only those that are usable.
|
||||
usableNewIPs := filterUsableIPs(newIPs)
|
||||
|
||||
// Check if interface is up and has usable IPs.
|
||||
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
||||
activeInterfaceExists = true
|
||||
}
|
||||
|
||||
// Compare interface states and IPs (interfaceIPsEqual will itself filter the IPs).
|
||||
if !interfaceStatesEqual(&oldIface, &newIface) || !interfaceIPsEqual(oldIPs, newIPs) {
|
||||
if newIface.IsUp() && len(usableNewIPs) > 0 {
|
||||
changed = true
|
||||
changeIPs = usableNewIPs
|
||||
mainLog.Load().Debug().
|
||||
Str("interface", ifaceName).
|
||||
Interface("old_ips", oldIPs).
|
||||
Interface("new_ips", usableNewIPs).
|
||||
Msg("Interface state or IPs changed")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if the default route changed, set changed to true
|
||||
if delta.New.DefaultRouteInterface != delta.Old.DefaultRouteInterface {
|
||||
changed = true
|
||||
mainLog.Load().Debug().Msgf("Default route changed from %s to %s", delta.Old.DefaultRouteInterface, delta.New.DefaultRouteInterface)
|
||||
}
|
||||
|
||||
if !changed {
|
||||
mainLog.Load().Debug().Msg("Ignoring interface change - no valid interfaces affected")
|
||||
// check if the default IPs are still on an interface that is up
|
||||
ValidateDefaultLocalIPsFromDelta(delta.New)
|
||||
return
|
||||
}
|
||||
|
||||
if !activeInterfaceExists {
|
||||
mainLog.Load().Debug().Msg("No active interfaces found, skipping reinitialization")
|
||||
return
|
||||
}
|
||||
|
||||
// Get IPs from default route interface in new state
|
||||
selfIP := defaultRouteIP()
|
||||
|
||||
// Ensure that selfIP is an IPv4 address.
|
||||
// If defaultRouteIP mistakenly returns an IPv6 (such as a ULA), clear it
|
||||
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() == nil {
|
||||
mainLog.Load().Debug().Msgf("defaultRouteIP returned a non-IPv4 address: %s, ignoring it", selfIP)
|
||||
selfIP = ""
|
||||
}
|
||||
var ipv6 string
|
||||
|
||||
if delta.New.DefaultRouteInterface != "" {
|
||||
mainLog.Load().Debug().Msgf("default route interface: %s, IPs: %v", delta.New.DefaultRouteInterface, delta.New.InterfaceIPs[delta.New.DefaultRouteInterface])
|
||||
for _, ip := range delta.New.InterfaceIPs[delta.New.DefaultRouteInterface] {
|
||||
ipAddr, _ := netip.ParsePrefix(ip.String())
|
||||
addr := ipAddr.Addr()
|
||||
if selfIP == "" && addr.Is4() {
|
||||
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
|
||||
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
selfIP = addr.String()
|
||||
}
|
||||
}
|
||||
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
ipv6 = addr.String()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If no default route interface is set yet, use the changed IPs
|
||||
mainLog.Load().Debug().Msgf("no default route interface found, using changed IPs: %v", changeIPs)
|
||||
for _, ip := range changeIPs {
|
||||
ipAddr, _ := netip.ParsePrefix(ip.String())
|
||||
addr := ipAddr.Addr()
|
||||
if selfIP == "" && addr.Is4() {
|
||||
mainLog.Load().Debug().Msgf("checking IP: %s", addr.String())
|
||||
if !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
selfIP = addr.String()
|
||||
}
|
||||
}
|
||||
if addr.Is6() && !addr.IsLoopback() && !addr.IsLinkLocalUnicast() {
|
||||
ipv6 = addr.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only set the IPv4 default if selfIP is a valid IPv4 address.
|
||||
if ip := net.ParseIP(selfIP); ip != nil && ip.To4() != nil {
|
||||
ctrld.SetDefaultLocalIPv4(ip)
|
||||
if !isMobile() && p.ciTable != nil {
|
||||
p.ciTable.SetSelfIP(selfIP)
|
||||
}
|
||||
}
|
||||
if ip := net.ParseIP(ipv6); ip != nil {
|
||||
ctrld.SetDefaultLocalIPv6(ip)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Set default local IPv4: %s, IPv6: %s", selfIP, ipv6)
|
||||
|
||||
// we only trigger recovery flow for network changes on non router devices
|
||||
if router.Name() == "" {
|
||||
p.handleRecovery(RecoveryReasonNetworkChange)
|
||||
}
|
||||
})
|
||||
|
||||
mon.Start()
|
||||
mainLog.Load().Debug().Msg("Network monitor started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// interfaceStatesEqual compares two interface states
|
||||
func interfaceStatesEqual(a, b *netmon.Interface) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == b
|
||||
}
|
||||
return a.IsUp() == b.IsUp()
|
||||
}
|
||||
|
||||
// filterUsableIPs is a helper that returns only "usable" IP prefixes,
|
||||
// filtering out link-local, loopback, multicast, unspecified, broadcast, or CGNAT addresses.
|
||||
func filterUsableIPs(prefixes []netip.Prefix) []netip.Prefix {
|
||||
var usable []netip.Prefix
|
||||
for _, p := range prefixes {
|
||||
addr := p.Addr()
|
||||
if addr.IsLinkLocalUnicast() ||
|
||||
addr.IsLoopback() ||
|
||||
addr.IsMulticast() ||
|
||||
addr.IsUnspecified() ||
|
||||
addr.IsLinkLocalMulticast() ||
|
||||
(addr.Is4() && addr.String() == "255.255.255.255") ||
|
||||
tsaddr.CGNATRange().Contains(addr) {
|
||||
continue
|
||||
}
|
||||
usable = append(usable, p)
|
||||
}
|
||||
return usable
|
||||
}
|
||||
|
||||
// Modified interfaceIPsEqual compares only the usable (non-link local, non-loopback, etc.) IP addresses.
|
||||
func interfaceIPsEqual(a, b []netip.Prefix) bool {
|
||||
aUsable := filterUsableIPs(a)
|
||||
bUsable := filterUsableIPs(b)
|
||||
if len(aUsable) != len(bUsable) {
|
||||
return false
|
||||
}
|
||||
|
||||
aMap := make(map[string]bool)
|
||||
for _, ip := range aUsable {
|
||||
aMap[ip.String()] = true
|
||||
}
|
||||
for _, ip := range bUsable {
|
||||
if !aMap[ip.String()] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkUpstreamOnce sends a test query to the specified upstream.
|
||||
// Returns nil if the upstream responds successfully.
|
||||
func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) error {
|
||||
mainLog.Load().Debug().Msgf("Starting check for upstream: %s", upstream)
|
||||
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Failed to create resolver for upstream %s", upstream)
|
||||
return err
|
||||
}
|
||||
|
||||
timeout := 1000 * time.Millisecond
|
||||
if uc.Timeout > 0 {
|
||||
timeout = time.Millisecond * time.Duration(uc.Timeout)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Timeout for upstream %s: %s", upstream, timeout)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
uc.ReBootstrap()
|
||||
mainLog.Load().Debug().Msgf("Rebootstrapping resolver for upstream: %s", upstream)
|
||||
|
||||
start := time.Now()
|
||||
msg := uc.VerifyMsg()
|
||||
_, err = resolver.Resolve(ctx, msg)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("Upstream %s check failed after %v", upstream, duration)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s responded successfully in %v", upstream, duration)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// handleRecovery performs a unified recovery by removing DNS settings,
|
||||
// canceling existing recovery checks for network changes, but coalescing duplicate
|
||||
// upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout),
|
||||
// and then re-applying the DNS settings.
|
||||
func (p *prog) handleRecovery(reason RecoveryReason) {
|
||||
mainLog.Load().Debug().Msg("Starting recovery process: removing DNS settings")
|
||||
|
||||
// For network changes, cancel any existing recovery check because the network state has changed.
|
||||
if reason == RecoveryReasonNetworkChange {
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel != nil {
|
||||
mainLog.Load().Debug().Msg("Cancelling existing recovery check (network change)")
|
||||
p.recoveryCancel()
|
||||
p.recoveryCancel = nil
|
||||
}
|
||||
p.recoveryCancelMu.Unlock()
|
||||
} else {
|
||||
// For upstream failures, if a recovery is already in progress, do nothing new.
|
||||
p.recoveryCancelMu.Lock()
|
||||
if p.recoveryCancel != nil {
|
||||
mainLog.Load().Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
|
||||
p.recoveryCancelMu.Unlock()
|
||||
return
|
||||
}
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
// Create a new recovery context without a fixed timeout.
|
||||
p.recoveryCancelMu.Lock()
|
||||
recoveryCtx, cancel := context.WithCancel(context.Background())
|
||||
p.recoveryCancel = cancel
|
||||
p.recoveryCancelMu.Unlock()
|
||||
|
||||
// Immediately remove our DNS settings from the interface.
|
||||
// set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
|
||||
p.recoveryRunning.Store(true)
|
||||
// we do not want to restore any static DNS settings
|
||||
// we must try to get the DHCP values, any static DNS settings
|
||||
// will be appended to nameservers from the saved interface values
|
||||
p.resetDNS(false, false)
|
||||
|
||||
// For an OS failure, reinitialize OS resolver nameservers immediately.
|
||||
if reason == RecoveryReasonOSFailure {
|
||||
mainLog.Load().Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
} else {
|
||||
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
}
|
||||
|
||||
// Build upstream map based on the recovery reason.
|
||||
upstreams := p.buildRecoveryUpstreams(reason)
|
||||
|
||||
// Wait indefinitely until one of the upstreams recovers.
|
||||
recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Recovery canceled; DNS settings remain removed")
|
||||
p.recoveryCancelMu.Lock()
|
||||
p.recoveryCancel = nil
|
||||
p.recoveryCancelMu.Unlock()
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered)
|
||||
|
||||
// reset the upstream failure count and down state
|
||||
p.um.reset(recovered)
|
||||
|
||||
// For network changes we also reinitialize the OS resolver.
|
||||
if reason == RecoveryReasonNetworkChange {
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
|
||||
} else {
|
||||
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply our DNS settings back and log the interface state.
|
||||
p.setDNS()
|
||||
p.logInterfacesState()
|
||||
|
||||
// allow watchdogs to put the listener back on the interface if its changed for any reason
|
||||
p.recoveryRunning.Store(false)
|
||||
|
||||
// Clear the recovery cancellation for a clean slate.
|
||||
p.recoveryCancelMu.Lock()
|
||||
p.recoveryCancel = nil
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
// waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers.
|
||||
// It returns the name of the recovered upstream or an error if the check times out.
|
||||
func (p *prog) waitForUpstreamRecovery(ctx context.Context, upstreams map[string]*ctrld.UpstreamConfig) (string, error) {
|
||||
recoveredCh := make(chan string, 1)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
mainLog.Load().Debug().Msgf("Starting upstream recovery check for %d upstreams", len(upstreams))
|
||||
|
||||
for name, uc := range upstreams {
|
||||
wg.Add(1)
|
||||
go func(name string, uc *ctrld.UpstreamConfig) {
|
||||
defer wg.Done()
|
||||
mainLog.Load().Debug().Msgf("Starting recovery check loop for upstream: %s", name)
|
||||
attempts := 0
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mainLog.Load().Debug().Msgf("Context canceled for upstream %s", name)
|
||||
return
|
||||
default:
|
||||
attempts++
|
||||
// checkUpstreamOnce will reset any failure counters on success.
|
||||
if err := p.checkUpstreamOnce(name, uc); err == nil {
|
||||
mainLog.Load().Debug().Msgf("Upstream %s recovered successfully", name)
|
||||
select {
|
||||
case recoveredCh <- name:
|
||||
mainLog.Load().Debug().Msgf("Sent recovery notification for upstream %s", name)
|
||||
default:
|
||||
mainLog.Load().Debug().Msg("Recovery channel full, another upstream already recovered")
|
||||
}
|
||||
return
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Upstream %s check failed, sleeping before retry", name)
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
|
||||
// if this is the upstreamOS and it's the 3rd attempt (or multiple of 3),
|
||||
// we should try to reinit the OS resolver to ensure we can recover
|
||||
if name == upstreamOS && attempts%3 == 0 {
|
||||
mainLog.Load().Debug().Msgf("UpstreamOS check failed on attempt %d, reinitializing OS resolver", attempts)
|
||||
ns := ctrld.InitializeOsResolver(true)
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Warn().Msg("No nameservers found for OS resolver; using existing values")
|
||||
} else {
|
||||
mainLog.Load().Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}(name, uc)
|
||||
}
|
||||
|
||||
var recovered string
|
||||
select {
|
||||
case recovered = <-recoveredCh:
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
}
|
||||
wg.Wait()
|
||||
return recovered, nil
|
||||
}
|
||||
|
||||
// buildRecoveryUpstreams constructs the map of upstream configurations to test.
|
||||
// For OS failures we supply the manual OS resolver upstream configuration.
|
||||
// For network change or regular failure we use the upstreams defined in p.cfg (ignoring OS).
|
||||
func (p *prog) buildRecoveryUpstreams(reason RecoveryReason) map[string]*ctrld.UpstreamConfig {
|
||||
upstreams := make(map[string]*ctrld.UpstreamConfig)
|
||||
switch reason {
|
||||
case RecoveryReasonOSFailure:
|
||||
upstreams[upstreamOS] = osUpstreamConfig
|
||||
case RecoveryReasonNetworkChange, RecoveryReasonRegularFailure:
|
||||
// Use all configured upstreams except any OS type.
|
||||
for k, uc := range p.cfg.Upstream {
|
||||
if uc.Type != ctrld.ResolverTypeOS {
|
||||
upstreams[upstreamPrefix+k] = uc
|
||||
}
|
||||
}
|
||||
}
|
||||
return upstreams
|
||||
}
|
||||
|
||||
// ValidateDefaultLocalIPsFromDelta checks if the default local IPv4 and IPv6 stored
|
||||
// are still present in the new network state (provided by delta.New).
|
||||
// If a stored default IP is no longer active, it resets that default (sets it to nil)
|
||||
// so that it won't be used in subsequent custom dialer contexts.
|
||||
func ValidateDefaultLocalIPsFromDelta(newState *netmon.State) {
|
||||
currentIPv4 := ctrld.GetDefaultLocalIPv4()
|
||||
currentIPv6 := ctrld.GetDefaultLocalIPv6()
|
||||
|
||||
// Build a map of active IP addresses from the new state.
|
||||
activeIPs := make(map[string]bool)
|
||||
for _, prefixes := range newState.InterfaceIPs {
|
||||
for _, prefix := range prefixes {
|
||||
activeIPs[prefix.Addr().String()] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the default IPv4 is still active.
|
||||
if currentIPv4 != nil && !activeIPs[currentIPv4.String()] {
|
||||
mainLog.Load().Debug().Msgf("DefaultLocalIPv4 %s is no longer active in the new state. Resetting.", currentIPv4)
|
||||
ctrld.SetDefaultLocalIPv4(nil)
|
||||
}
|
||||
|
||||
// Check if the default IPv6 is still active.
|
||||
if currentIPv6 != nil && !activeIPs[currentIPv6.String()] {
|
||||
mainLog.Load().Debug().Msgf("DefaultLocalIPv6 %s is no longer active in the new state. Resetting.", currentIPv6)
|
||||
ctrld.SetDefaultLocalIPv6(nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,6 +75,7 @@ func Test_canonicalName(t *testing.T) {
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
||||
p := &prog{cfg: cfg}
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
@@ -365,6 +366,9 @@ func Test_isLanHostnameQuery(t *testing.T) {
|
||||
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
||||
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
||||
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
||||
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
||||
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
||||
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
@@ -414,6 +418,27 @@ func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSrvLanLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isSrvLookup bool
|
||||
}{
|
||||
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
||||
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
||||
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isWanClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AppCallback provides hooks for injecting certain functionalities
|
||||
// from mobile platforms to main ctrld cli.
|
||||
type AppCallback struct {
|
||||
@@ -11,9 +18,78 @@ type AppCallback struct {
|
||||
|
||||
// AppConfig allows overwriting ctrld cli flags from mobile platforms.
|
||||
type AppConfig struct {
|
||||
CdUID string
|
||||
HomeDir string
|
||||
UpstreamProto string
|
||||
Verbose int
|
||||
LogPath string
|
||||
CdUID string
|
||||
ProvisionID string
|
||||
CustomHostname string
|
||||
HomeDir string
|
||||
UpstreamProto string
|
||||
Verbose int
|
||||
LogPath string
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHTTPTimeout = 30 * time.Second
|
||||
defaultMaxRetries = 3
|
||||
downloadServerIp = "23.171.240.151"
|
||||
)
|
||||
|
||||
// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback
|
||||
func httpClientWithFallback(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
// Prefer IPv4 over IPv6
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
FallbackDelay: 1 * time.Millisecond, // Very small delay to prefer IPv4
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// doWithRetry performs an HTTP request with retries
|
||||
func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) {
|
||||
var lastErr error
|
||||
client := httpClientWithFallback(defaultHTTPTimeout)
|
||||
var ipReq *http.Request
|
||||
if ip != "" {
|
||||
ipReq = req.Clone(req.Context())
|
||||
ipReq.Host = ip
|
||||
ipReq.URL.Host = ip
|
||||
}
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(time.Second * time.Duration(attempt+1)) // Exponential backoff
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if ipReq != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("dial to %q failed", req.Host)
|
||||
mainLog.Load().Warn().Msgf("fallback to direct IP to download prod version: %q", ip)
|
||||
resp, err = client.Do(ipReq)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
mainLog.Load().Debug().Err(err).
|
||||
Str("method", req.Method).
|
||||
Str("url", req.URL.String()).
|
||||
Msgf("HTTP request attempt %d/%d failed", attempt+1, maxRetries)
|
||||
}
|
||||
return nil, fmt.Errorf("failed after %d attempts to %s %s: %v", maxRetries, req.Method, req.URL, lastErr)
|
||||
}
|
||||
|
||||
// Helper for making GET requests with retries
|
||||
func getWithRetry(url string, ip string) (*http.Response, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doWithRetry(req, defaultMaxRetries, ip)
|
||||
}
|
||||
|
||||
204
cmd/cli/log_writer.go
Normal file
204
cmd/cli/log_writer.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
logWriterSize = 1024 * 1024 * 5 // 5 MB
|
||||
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
|
||||
logWriterInitialSize = 32 * 1024 // 32 KB
|
||||
logWriterSentInterval = time.Minute
|
||||
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
|
||||
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
|
||||
)
|
||||
|
||||
type logViewResponse struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type logSentResponse struct {
|
||||
Size int64 `json:"size"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type logReader struct {
|
||||
r io.ReadCloser
|
||||
size int64
|
||||
}
|
||||
|
||||
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
|
||||
type logWriter struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
size int
|
||||
}
|
||||
|
||||
// newLogWriter creates an internal log writer.
|
||||
func newLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSize)
|
||||
}
|
||||
|
||||
// newSmallLogWriter creates an internal log writer with small buffer size.
|
||||
func newSmallLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSmallSize)
|
||||
}
|
||||
|
||||
// newLogWriterWithSize creates an internal log writer with a given buffer size.
|
||||
func newLogWriterWithSize(size int) *logWriter {
|
||||
lw := &logWriter{size: size}
|
||||
return lw
|
||||
}
|
||||
|
||||
func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
|
||||
// If writing p causes overflows, discard old data.
|
||||
if lw.buf.Len()+len(p) > lw.size {
|
||||
buf := lw.buf.Bytes()
|
||||
haveEndMarker := false
|
||||
// If there's init end marker already, preserve the data til the marker.
|
||||
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
|
||||
buf = buf[:idx+len(logWriterInitEndMarker)]
|
||||
haveEndMarker = true
|
||||
} else {
|
||||
// Otherwise, preserve the initial size data.
|
||||
buf = buf[:logWriterInitialSize]
|
||||
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
|
||||
buf = buf[:idx]
|
||||
}
|
||||
}
|
||||
lw.buf.Reset()
|
||||
lw.buf.Write(buf)
|
||||
if !haveEndMarker {
|
||||
lw.buf.WriteString(logWriterInitEndMarker) // indicate that the log was truncated.
|
||||
}
|
||||
}
|
||||
// If p is bigger than buffer size, truncate p by half until its size is smaller.
|
||||
for len(p)+lw.buf.Len() > lw.size {
|
||||
p = p[len(p)/2:]
|
||||
}
|
||||
return lw.buf.Write(p)
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func (p *prog) initLogging(backup bool) {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
logWriters := initLoggingWithBackup(backup)
|
||||
|
||||
// Initializing internal logging after global logging.
|
||||
p.initInternalLogging(logWriters)
|
||||
}
|
||||
|
||||
// initInternalLogging performs internal logging if there's no log enabled.
|
||||
func (p *prog) initInternalLogging(writers []io.Writer) {
|
||||
if !p.needInternalLogging() {
|
||||
return
|
||||
}
|
||||
p.initInternalLogWriterOnce.Do(func() {
|
||||
mainLog.Load().Notice().Msg("internal logging enabled")
|
||||
p.internalLogWriter = newLogWriter()
|
||||
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
|
||||
p.internalWarnLogWriter = newSmallLogWriter()
|
||||
})
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
// If ctrld was run without explicit verbose level,
|
||||
// run the internal logging at debug level, so we could
|
||||
// have enough information for troubleshooting.
|
||||
if verbose == 0 {
|
||||
for i := range writers {
|
||||
w := &zerolog.FilteredLevelWriter{
|
||||
Writer: zerolog.LevelWriterAdapter{Writer: writers[i]},
|
||||
Level: zerolog.NoticeLevel,
|
||||
}
|
||||
writers[i] = w
|
||||
}
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
writers = append(writers, lw)
|
||||
writers = append(writers, &zerolog.FilteredLevelWriter{
|
||||
Writer: zerolog.LevelWriterAdapter{Writer: wlw},
|
||||
Level: zerolog.WarnLevel,
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// needInternalLogging reports whether prog needs to run internal logging.
|
||||
func (p *prog) needInternalLogging() bool {
|
||||
// Do not run in non-cd mode.
|
||||
if cdUID == "" {
|
||||
return false
|
||||
}
|
||||
// Do not run if there's already log file.
|
||||
if p.cfg.Service.LogPath != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *prog) logReader() (*logReader, error) {
|
||||
if p.needInternalLogging() {
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
if lw == nil {
|
||||
return nil, errors.New("nil internal log writer")
|
||||
}
|
||||
if wlw == nil {
|
||||
return nil, errors.New("nil internal warn log writer")
|
||||
}
|
||||
// Normal log content.
|
||||
lw.mu.Lock()
|
||||
lwReader := bytes.NewReader(lw.buf.Bytes())
|
||||
lwSize := lw.buf.Len()
|
||||
lw.mu.Unlock()
|
||||
// Warn log content.
|
||||
wlw.mu.Lock()
|
||||
wlwReader := bytes.NewReader(wlw.buf.Bytes())
|
||||
wlwSize := wlw.buf.Len()
|
||||
wlw.mu.Unlock()
|
||||
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader)
|
||||
lr := &logReader{r: io.NopCloser(reader)}
|
||||
lr.size = int64(lwSize + wlwSize)
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("internal log is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
if p.cfg.Service.LogPath == "" {
|
||||
return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil
|
||||
}
|
||||
f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lr := &logReader{r: f}
|
||||
if st, err := f.Stat(); err == nil {
|
||||
lr.size = st.Size()
|
||||
} else {
|
||||
return nil, fmt.Errorf("f.Stat: %w", err)
|
||||
}
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("log file is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
85
cmd/cli/log_writer_test.go
Normal file
85
cmd/cli/log_writer_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_logWriter_Write(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
data := strings.Repeat("A", size)
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
newData := "B"
|
||||
halfData := strings.Repeat("A", len(data)/2) + logWriterInitEndMarker
|
||||
lw.Write([]byte(newData))
|
||||
if lw.buf.String() != halfData+newData {
|
||||
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
|
||||
}
|
||||
|
||||
bigData := strings.Repeat("B", 256*1024)
|
||||
expected := halfData + strings.Repeat("B", 16*1024)
|
||||
lw.Write([]byte(bigData))
|
||||
if lw.buf.String() != expected {
|
||||
t.Fatalf("unexpected big buf content: %v", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_ConcurrentWrite(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
n := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
lw.Write([]byte(strings.Repeat("A", i)))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if lw.buf.Len() > lw.size {
|
||||
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_MarkerInitEnd(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
|
||||
paddingSize := 10
|
||||
// Writing half of the size, minus len(end marker) and padding size.
|
||||
dataSize := size/2 - len(logWriterInitEndMarker) - paddingSize
|
||||
data := strings.Repeat("A", dataSize)
|
||||
// Inserting newline for making partial init data
|
||||
data += "\n"
|
||||
// Filling left over buffer to make the log full.
|
||||
// The data length: len(end marker) + padding size - 1 (for newline above) + size/2
|
||||
data += strings.Repeat("A", len(logWriterInitEndMarker)+paddingSize-1+(size/2))
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
lw.Write([]byte("B"))
|
||||
lw.Write([]byte(strings.Repeat("B", 256*1024)))
|
||||
firstIdx := strings.Index(lw.buf.String(), logWriterInitEndMarker)
|
||||
lastIdx := strings.LastIndex(lw.buf.String(), logWriterInitEndMarker)
|
||||
// Check if init end marker present.
|
||||
if firstIdx == -1 || lastIdx == -1 {
|
||||
t.Fatalf("missing init end marker: %s", lw.buf.String())
|
||||
}
|
||||
// Check if init end marker appears only once.
|
||||
if firstIdx != lastIdx {
|
||||
t.Fatalf("log init end marker appears more than once: %s", lw.buf.String())
|
||||
}
|
||||
// Ensure that we have the correct init log data.
|
||||
if !strings.Contains(lw.buf.String(), strings.Repeat("A", dataSize)+logWriterInitEndMarker) {
|
||||
t.Fatalf("unexpected log content: %s", lw.buf.String())
|
||||
}
|
||||
}
|
||||
@@ -39,6 +39,7 @@ var (
|
||||
skipSelfChecks bool
|
||||
cleanup bool
|
||||
startOnly bool
|
||||
rfc1918 bool
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
@@ -88,22 +89,33 @@ func initConsoleLogging() {
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
mainLog.Store(&l)
|
||||
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
case verbose == 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
case verbose > 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func initLogging() {
|
||||
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
|
||||
// to be used for all interactive commands.
|
||||
//
|
||||
// Current log file config will also be ignored.
|
||||
func initInteractiveLogging() {
|
||||
old := cfg.Service.LogPath
|
||||
cfg.Service.LogPath = ""
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(true)
|
||||
initLoggingWithBackup(false)
|
||||
cfg.Service.LogPath = old
|
||||
l := zerolog.New(io.Discard)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// initLoggingWithBackup initializes log setup base on current config.
|
||||
@@ -112,8 +124,8 @@ func initLogging() {
|
||||
// This is only used in runCmd for special handling in case of logging config
|
||||
// change in cd mode. Without special reason, the caller should use initLogging
|
||||
// wrapper instead of calling this function directly.
|
||||
func initLoggingWithBackup(doBackup bool) {
|
||||
writers := []io.Writer{io.Discard}
|
||||
func initLoggingWithBackup(doBackup bool) []io.Writer {
|
||||
var writers []io.Writer
|
||||
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
|
||||
// Create parent directory if necessary.
|
||||
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
|
||||
@@ -151,21 +163,22 @@ func initLoggingWithBackup(doBackup bool) {
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
return
|
||||
return writers
|
||||
case verbose == 1:
|
||||
logLevel = "info"
|
||||
case verbose > 1:
|
||||
logLevel = "debug"
|
||||
}
|
||||
if logLevel == "" {
|
||||
return
|
||||
return writers
|
||||
}
|
||||
level, err := zerolog.ParseLevel(logLevel)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set log level")
|
||||
return
|
||||
return writers
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
return writers
|
||||
}
|
||||
|
||||
func initCache() {
|
||||
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
patched := false
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
patched = true
|
||||
iface.Name = name
|
||||
mainLog.Load().Debug().Str("network_service", name).Msg("found network service name for interface")
|
||||
}
|
||||
return nil
|
||||
return patched, nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
|
||||
52
cmd/cli/net_linux.go
Normal file
52
cmd/cli/net_linux.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// Only non-virtual interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// validInterfacesMap returns a set containing non virtual interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
vis := virtualInterfaces()
|
||||
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
|
||||
if _, existed := vis[i.Name]; existed {
|
||||
return
|
||||
}
|
||||
m[i.Name] = struct{}{}
|
||||
})
|
||||
// Fallback to default route interface if found nothing.
|
||||
if len(m) == 0 {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return m
|
||||
}
|
||||
m[defaultRoute.InterfaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// virtualInterfaces returns a map of virtual interfaces on current machine.
|
||||
func virtualInterfaces() map[string]struct{} {
|
||||
s := make(map[string]struct{})
|
||||
entries, _ := os.ReadDir("/sys/devices/virtual/net")
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
s[strings.TrimSpace(entry.Name())] = struct{}{}
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -1,11 +1,22 @@
|
||||
//go:build !darwin && !windows
|
||||
//go:build !darwin && !windows && !linux
|
||||
|
||||
package cli
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"net"
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error { return nil }
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
|
||||
|
||||
func validInterfacesMap() map[string]struct{} { return nil }
|
||||
// validInterfacesMap returns a set containing only default route interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]struct{}{defaultRoute.InterfaceName: {}}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"os"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
return nil
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
@@ -20,15 +26,68 @@ func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bo
|
||||
|
||||
// validInterfacesMap returns a set of all physical interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
m := make(map[string]struct{})
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
ifaceName := strings.TrimSpace(scanner.Text())
|
||||
for _, ifaceName := range validInterfaces() {
|
||||
m[ifaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// validInterfaces returns a list of all physical interfaces.
|
||||
func validInterfaces() []string {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
|
||||
return nil
|
||||
}
|
||||
var adapters []string
|
||||
for _, i := range instances {
|
||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
|
||||
continue
|
||||
}
|
||||
|
||||
name, err := adapter.GetPropertyName()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
|
||||
continue
|
||||
}
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||
//
|
||||
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
||||
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
||||
physical, err := adapter.GetPropertyConnectorPresent()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property")
|
||||
continue
|
||||
}
|
||||
if !physical {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's a hardware interface. Checking only for connector present is not enough
|
||||
// because some interfaces are not physical but have a connector.
|
||||
hardware, err := adapter.GetPropertyHardwareInterface()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property")
|
||||
continue
|
||||
}
|
||||
if !hardware {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface")
|
||||
continue
|
||||
}
|
||||
|
||||
adapters = append(adapters, name)
|
||||
}
|
||||
return adapters
|
||||
}
|
||||
|
||||
42
cmd/cli/net_windows_test.go
Normal file
42
cmd/cli/net_windows_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_validInterfaces(t *testing.T) {
|
||||
verbose = 3
|
||||
initConsoleLogging()
|
||||
start := time.Now()
|
||||
ifaces := validInterfaces()
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
ifacesPowershell := validInterfacesPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(ifaces)
|
||||
slices.Sort(ifacesPowershell)
|
||||
if !slices.Equal(ifaces, ifacesPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces)
|
||||
}
|
||||
}
|
||||
|
||||
func validInterfacesPowershell() []string {
|
||||
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var res []string
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
ifaceName := strings.TrimSpace(scanner.Text())
|
||||
res = append(res, ifaceName)
|
||||
}
|
||||
return res
|
||||
}
|
||||
5
cmd/cli/nocgo.go
Normal file
5
cmd/cli/nocgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = false
|
||||
@@ -47,6 +47,9 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
|
||||
// TODO(cuonglm): use system API
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
// Note that networksetup won't modify search domains settings,
|
||||
// This assignment is just a placeholder to silent linter.
|
||||
_ = searchDomains
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name}
|
||||
args = append(args, nameservers...)
|
||||
@@ -70,11 +73,6 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
|
||||
// TODO(cuonglm): use system API
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
if err := setDNS(iface, ns); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name, "empty"}
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
@@ -83,8 +81,17 @@ func resetDNS(iface *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
err = setDNS(iface, ns)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
@@ -50,7 +51,17 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
ns = append(ns, netip.MustParseAddr(nameserver))
|
||||
}
|
||||
|
||||
if err := r.SetDNS(dns.OSConfig{Nameservers: ns}); err != nil {
|
||||
osConfig := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
|
||||
}
|
||||
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to set DNS")
|
||||
return err
|
||||
}
|
||||
@@ -76,8 +87,14 @@ func resetDNS(iface *net.Interface) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers("")
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
|
||||
@@ -71,35 +71,31 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
|
||||
}
|
||||
trySystemdResolve := false
|
||||
for i := 0; i < maxSetDNSAttempts; i++ {
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
break
|
||||
}
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
goto systemdResolve
|
||||
}
|
||||
if useSystemdResolved {
|
||||
if out, err := exec.Command("systemctl", "restart", "systemd-resolved").CombinedOutput(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not restart systemd-resolved: %s", string(out))
|
||||
}
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
systemdResolve:
|
||||
if trySystemdResolve {
|
||||
// Stop systemd-networkd and retry setting DNS.
|
||||
if out, err := exec.Command("systemctl", "stop", "systemd-networkd").CombinedOutput(); err != nil {
|
||||
@@ -119,8 +115,8 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -169,6 +165,7 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
|
||||
// TODO(cuonglm): handle DHCPv6 properly.
|
||||
mainLog.Load().Debug().Msg("checking for IPv6 availability")
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
c := client6.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
@@ -188,6 +185,8 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("IPv6 is not available")
|
||||
}
|
||||
|
||||
return ignoringEINTR(func() error {
|
||||
@@ -195,8 +194,15 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
})
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconffile.NameServers} {
|
||||
resolvconfFunc := func(_ string) []string { return resolvconffile.NameServers() }
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} {
|
||||
if ns := fn(iface.Name); len(ns) > 0 {
|
||||
return ns
|
||||
}
|
||||
|
||||
@@ -1,23 +1,27 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
const (
|
||||
v4InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `HKLM:\SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -30,14 +34,6 @@ func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) e
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
func setDnsPowershellCmd(iface *net.Interface, nameservers []string) string {
|
||||
nss := make([]string, 0, len(nameservers))
|
||||
for _, ns := range nameservers {
|
||||
nss = append(nss, strconv.Quote(ns))
|
||||
}
|
||||
return fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ServerAddresses (%s)", iface.Index, strings.Join(nss, ","))
|
||||
}
|
||||
|
||||
// setDNS sets the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
if len(nameservers) == 0 {
|
||||
@@ -46,28 +42,80 @@ func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
setDNSOnce.Do(func() {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
if hasLocalDnsServerRunning() {
|
||||
mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders")
|
||||
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
oldForwardersContent, _ := os.ReadFile(file)
|
||||
mainLog.Load().Debug().Msgf("Using forwarders file: %s", file)
|
||||
|
||||
oldForwardersContent, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent))
|
||||
}
|
||||
|
||||
hasLocalIPv6Listener := needLocalIPv6Listener()
|
||||
mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status")
|
||||
|
||||
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
|
||||
if !hasLocalIPv6Listener {
|
||||
return false
|
||||
}
|
||||
return s == "::1"
|
||||
})
|
||||
mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list")
|
||||
|
||||
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully wrote new forwarders file")
|
||||
}
|
||||
|
||||
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
||||
mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders")
|
||||
|
||||
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders")
|
||||
}
|
||||
}
|
||||
})
|
||||
out, err := powershell(setDnsPowershellCmd(iface, nameservers))
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
return fmt.Errorf("setDNS: %w", err)
|
||||
}
|
||||
var (
|
||||
serversV4 []netip.Addr
|
||||
serversV6 []netip.Addr
|
||||
)
|
||||
for _, ns := range nameservers {
|
||||
if addr, err := netip.ParseAddr(ns); err == nil {
|
||||
if addr.Is4() {
|
||||
serversV4 = append(serversV4, addr)
|
||||
} else {
|
||||
serversV6 = append(serversV6, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function.
|
||||
// searchDomains is still implemented for Windows just in case Windows API changes in future versions.
|
||||
_ = searchDomains
|
||||
|
||||
if len(serversV4) == 0 && len(serversV6) == 0 {
|
||||
return errors.New("invalid DNS nameservers")
|
||||
}
|
||||
if len(serversV4) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv4: %w", err)
|
||||
}
|
||||
}
|
||||
if len(serversV6) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv6: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -81,7 +129,7 @@ func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
resetDNSOnce.Do(func() {
|
||||
// See corresponding comment in setDNS.
|
||||
if windowsHasLocalDnsServerRunning() {
|
||||
if hasLocalDnsServerRunning() {
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
@@ -96,14 +144,23 @@ func resetDNS(iface *net.Interface) error {
|
||||
}
|
||||
})
|
||||
|
||||
// Restoring DHCP settings.
|
||||
cmd := fmt.Sprintf("Set-DnsClientServerAddress -InterfaceIndex %d -ResetServerAddresses", iface.Index)
|
||||
out, err := powershell(cmd)
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
return fmt.Errorf("resetDNS: %w", err)
|
||||
}
|
||||
// Restoring DHCP settings.
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv4: %w", err)
|
||||
}
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If there's static DNS saved, restoring it.
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
||||
v4ns := make([]string, 0, 2)
|
||||
v6ns := make([]string, 0, 2)
|
||||
@@ -115,17 +172,36 @@ func resetDNS(iface *net.Interface) error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, ns := range [][]string{v4ns, v6ns} {
|
||||
if len(ns) == 0 {
|
||||
continue
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoreDNS: %w", err)
|
||||
}
|
||||
|
||||
if len(v4ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
|
||||
if err := setDNS(iface, v4ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4): %w", err)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("setting static DNS for interface %q", iface.Name)
|
||||
if err := setDNS(iface, ns); err != nil {
|
||||
return err
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(v6ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
|
||||
if err := setDNS(iface, v6ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
@@ -146,37 +222,69 @@ func currentDNS(iface *net.Interface) []string {
|
||||
return ns
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys
|
||||
// like "NameServer" and "ProfileNameServer".
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err)
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("fallback luid.GUID: %w", err)
|
||||
}
|
||||
|
||||
var ns []string
|
||||
for _, path := range []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat} {
|
||||
keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat}
|
||||
for _, path := range keyPaths {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer k.Close()
|
||||
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
|
||||
value, _, err := k.GetStringValue(keyName)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName)
|
||||
continue
|
||||
}
|
||||
if len(value) > 0 {
|
||||
mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value)
|
||||
parsed := parseDNSServers(value)
|
||||
for _, pns := range parsed {
|
||||
if !slices.Contains(ns, pns) {
|
||||
ns = append(ns, pns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name)
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// parseDNSServers splits a DNS server string that may be comma- or space-separated,
|
||||
// and trims any extraneous whitespace or null characters.
|
||||
func parseDNSServers(val string) []string {
|
||||
fields := strings.FieldsFunc(val, func(r rune) bool {
|
||||
return r == ' ' || r == ','
|
||||
})
|
||||
var servers []string
|
||||
for _, f := range fields {
|
||||
trimmed := strings.TrimSpace(f)
|
||||
if len(trimmed) > 0 {
|
||||
servers = append(servers, trimmed)
|
||||
}
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
|
||||
// and also removing old forwarders if provided.
|
||||
func addDnsServerForwarders(nameservers, old []string) error {
|
||||
@@ -216,3 +324,9 @@ func removeDnsServerForwarders(nameservers []string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// powershell runs the given powershell command.
|
||||
func powershell(cmd string) ([]byte, error) {
|
||||
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
|
||||
return bytes.TrimSpace(out), err
|
||||
}
|
||||
|
||||
68
cmd/cli/os_windows_test.go
Normal file
68
cmd/cli/os_windows_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func Test_currentStaticDNS(t *testing.T) {
|
||||
iface, err := net.InterfaceByName(defaultIfaceName())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
start := time.Now()
|
||||
staticDns, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(staticDns)
|
||||
slices.Sort(staticDnsPowershell)
|
||||
if !slices.Equal(staticDns, staticDnsPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
|
||||
}
|
||||
}
|
||||
|
||||
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
762
cmd/cli/prog.go
762
cmd/cli/prog.go
File diff suppressed because it is too large
Load Diff
@@ -9,14 +9,12 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo"); err == nil {
|
||||
if r, err := newLoopbackOSConfigurator(); err == nil {
|
||||
useSystemdResolved = r.Mode() == "systemd-resolved"
|
||||
}
|
||||
// Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911
|
||||
@@ -39,6 +37,9 @@ func setDependencies(svc *service.Config) {
|
||||
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
|
||||
}
|
||||
}
|
||||
if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 {
|
||||
svc.Dependencies = append(svc.Dependencies, routerDeps...)
|
||||
}
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !freebsd && !darwin
|
||||
//go:build !linux && !freebsd && !darwin && !windows
|
||||
|
||||
package cli
|
||||
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
|
||||
@@ -55,3 +59,215 @@ func Test_prog_dnsWatchdogInterval(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_shouldUpgrade(t *testing.T) {
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is empty",
|
||||
},
|
||||
{
|
||||
name: "invalid version target",
|
||||
versionTarget: "invalid-version",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is invalid",
|
||||
},
|
||||
{
|
||||
name: "same version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version equals current version",
|
||||
},
|
||||
{
|
||||
name: "older version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.1.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version is older than current version",
|
||||
},
|
||||
{
|
||||
name: "patch upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow patch version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "minor upgrade allowed",
|
||||
versionTarget: "v1.1.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow minor version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "major upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "major downgrade blocked",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v2.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version downgrade",
|
||||
},
|
||||
{
|
||||
name: "version without v prefix",
|
||||
versionTarget: "1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should handle version target without v prefix",
|
||||
},
|
||||
{
|
||||
name: "complex version upgrade allowed",
|
||||
versionTarget: "v1.5.3",
|
||||
currentVersion: makeVersion("v1.4.2"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow complex version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "complex major upgrade blocked",
|
||||
versionTarget: "v3.1.0",
|
||||
currentVersion: makeVersion("v2.5.3"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block complex major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "pre-release version upgrade allowed",
|
||||
versionTarget: "v1.0.1-beta.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow pre-release version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "pre-release major upgrade blocked",
|
||||
versionTarget: "v2.0.0-alpha.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block pre-release major version upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
|
||||
|
||||
// Call the function and capture the result
|
||||
result := shouldUpgrade(tc.versionTarget, tc.currentVersion, &testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_selfUpgradeCheck(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipped due to Windows file locking issue on Github Action runners")
|
||||
}
|
||||
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow upgrade and attempt to perform it",
|
||||
},
|
||||
{
|
||||
name: "upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block upgrade and not attempt to perform it",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
|
||||
|
||||
// Call the function and capture the result
|
||||
result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, &testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_performUpgrade(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipped due to Windows file locking issue on Github Action runners")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid version target",
|
||||
versionTarget: "v1.0.1",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade with valid version target",
|
||||
},
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade even with empty version target",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Call the function and capture the result
|
||||
result := performUpgrade(tc.versionTarget)
|
||||
assert.Equal(t, tc.expectedResult, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
14
cmd/cli/prog_windows.go
Normal file
14
cmd/cli/prog_windows.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
func setDependencies(svc *service.Config) {
|
||||
if hasLocalDnsServerRunning() {
|
||||
svc.Dependencies = []string{"DNS"}
|
||||
}
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
@@ -3,11 +3,38 @@ package cli
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
|
||||
// Returns nil if no nameservers are found.
|
||||
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the file for "nameserver" lines
|
||||
var currentNS []string
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "nameserver") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) >= 2 {
|
||||
currentNS = append(currentNS, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentNS, nil
|
||||
}
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
@@ -40,7 +67,7 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if p.leakingQuery.Load() {
|
||||
if p.recoveryRunning.Load() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
@@ -50,17 +77,81 @@ func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn f
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msg("/etc/resolv.conf changes detected, reverting to ctrld setting")
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
expectedNS := make([]string, len(ns))
|
||||
for i, addr := range ns {
|
||||
expectedNS[i] = addr.String()
|
||||
}
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
|
||||
var foundNS []string
|
||||
var err error
|
||||
|
||||
maxRetries := 1
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
// If we found nameservers, break out of retry loop
|
||||
if len(foundNS) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
if retry < maxRetries-1 {
|
||||
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
|
||||
// If we found nameservers, check if they match what we expect
|
||||
if len(foundNS) > 0 {
|
||||
// Check if the nameservers match exactly what we expect
|
||||
matches := len(foundNS) == len(expectedNS)
|
||||
if matches {
|
||||
for i := range foundNS {
|
||||
if foundNS[i] != expectedNS[i] {
|
||||
matches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
Msg("checking nameservers")
|
||||
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter.
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -24,12 +24,17 @@ func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
oc.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file")
|
||||
}
|
||||
return r.SetDNS(oc)
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
r, err := dns.NewOSConfigurator(func(format string, args ...any) {}, &health.Tracker{}, &controlknobs.Knobs{}, "lo") // interface name does not matter.
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@@ -40,3 +45,8 @@ func shouldWatchResolvconf() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface.
|
||||
func newLoopbackOSConfigurator() (dns.OSConfigurator, error) {
|
||||
return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo")
|
||||
}
|
||||
|
||||
14
cmd/cli/search_domains_unix.go
Normal file
14
cmd/cli/search_domains_unix.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
return resolvconffile.SearchDomains()
|
||||
}
|
||||
43
cmd/cli/search_domains_windows.go
Normal file
43
cmd/cli/search_domains_windows.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
flags := winipcfg.GAAFlagIncludeGateways |
|
||||
winipcfg.GAAFlagIncludePrefix
|
||||
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err)
|
||||
}
|
||||
|
||||
var sds []dnsname.FQDN
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
continue
|
||||
}
|
||||
|
||||
for a := aa.FirstDNSSuffix; a != nil; a = a.Next {
|
||||
d, err := dnsname.ToFQDN(a.String())
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to parse domain: %s", a.String())
|
||||
continue
|
||||
}
|
||||
sds = append(sds, d)
|
||||
}
|
||||
}
|
||||
return sds, nil
|
||||
}
|
||||
@@ -22,7 +22,7 @@ func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||
logger.Fatal().Err(err).Msg("could not determine executable")
|
||||
}
|
||||
args := []string{"uninstall"}
|
||||
if !deactivationPinNotSet() {
|
||||
if deactivationPinSet() {
|
||||
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin.Load()))
|
||||
}
|
||||
cmd := exec.Command(bin, args...)
|
||||
|
||||
12
cmd/cli/self_upgrade_others.go
Normal file
12
cmd/cli/self_upgrade_others.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running a detached child command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{Setsid: true}
|
||||
}
|
||||
18
cmd/cli/self_upgrade_windows.go
Normal file
18
cmd/cli/self_upgrade_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags?redirectedfrom=MSDN
|
||||
|
||||
// SYSCALL_CREATE_NO_WINDOW set flag to run process without a console window.
|
||||
const SYSCALL_CREATE_NO_WINDOW = 0x08000000
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running self-upgrade command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | SYSCALL_CREATE_NO_WINDOW,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
@@ -4,12 +4,16 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/unit"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/openwrt"
|
||||
)
|
||||
|
||||
// newService wraps service.New call to return service.Service
|
||||
@@ -130,6 +134,59 @@ func (s *systemd) Status() (service.Status, error) {
|
||||
return s.Service.Status()
|
||||
}
|
||||
|
||||
func (s *systemd) Start() error {
|
||||
const systemdUnitFile = "/etc/systemd/system/ctrld.service"
|
||||
f, err := os.Open(systemdUnitFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if opts, change := ensureSystemdKillMode(f); change {
|
||||
mode := os.FileMode(0644)
|
||||
buf, err := io.ReadAll(unit.Serialize(opts))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(systemdUnitFile, buf, mode); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("systemctl daemon-reload failed: %w\n%s", err, string(out))
|
||||
}
|
||||
mainLog.Load().Debug().Msg("set KillMode=process successfully")
|
||||
}
|
||||
return s.Service.Start()
|
||||
}
|
||||
|
||||
// ensureSystemdKillMode ensure systemd unit file is configured with KillMode=process.
|
||||
// This is necessary for running self-upgrade flow.
|
||||
func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) {
|
||||
opts, err := unit.DeserializeOptions(r)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to deserialize options")
|
||||
return
|
||||
}
|
||||
change = true
|
||||
needKillModeOpt := true
|
||||
killModeOpt := unit.NewUnitOption("Service", "KillMode", "process")
|
||||
for _, opt := range opts {
|
||||
if opt.Match(killModeOpt) {
|
||||
needKillModeOpt = false
|
||||
change = false
|
||||
break
|
||||
}
|
||||
if opt.Section == killModeOpt.Section && opt.Name == killModeOpt.Name {
|
||||
opt.Value = killModeOpt.Value
|
||||
needKillModeOpt = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needKillModeOpt {
|
||||
opts = append(opts, killModeOpt)
|
||||
}
|
||||
return opts, change
|
||||
}
|
||||
|
||||
func newLaunchd(s service.Service) *launchd {
|
||||
return &launchd{
|
||||
Service: s,
|
||||
@@ -156,17 +213,22 @@ func (l *launchd) Status() (service.Status, error) {
|
||||
type task struct {
|
||||
f func() error
|
||||
abortOnError bool
|
||||
Name string
|
||||
}
|
||||
|
||||
func doTasks(tasks []task) bool {
|
||||
var prevErr error
|
||||
for _, task := range tasks {
|
||||
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
|
||||
if err := task.f(); err != nil {
|
||||
if task.abortOnError {
|
||||
mainLog.Load().Error().Msg(errors.Join(prevErr, err).Error())
|
||||
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
|
||||
return false
|
||||
}
|
||||
prevErr = err
|
||||
// if this is darwin stop command, dont print debug
|
||||
// since launchctl complains on every start
|
||||
if runtime.GOOS != "darwin" || task.Name != "Stop" {
|
||||
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -187,6 +249,13 @@ func checkHasElevatedPrivilege() {
|
||||
func unixSystemVServiceStatus() (service.Status, error) {
|
||||
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
|
||||
if err != nil {
|
||||
// Specific case for openwrt >= 24.10, it returns non-success code
|
||||
// for above status command, which may not right.
|
||||
if router.Name() == openwrt.Name {
|
||||
if string(bytes.ToLower(bytes.TrimSpace(out))) == "inactive" {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -13,3 +13,10 @@ func hasElevatedPrivilege() (bool, error) {
|
||||
func openLogFile(path string, flags int) (*os.File, error) {
|
||||
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||
}
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool { return false }
|
||||
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 }
|
||||
|
||||
28
cmd/cli/service_test.go
Normal file
28
cmd/cli/service_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_ensureSystemdKillMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
unitFile string
|
||||
wantChange bool
|
||||
}{
|
||||
{"no KillMode", "[Service]\nExecStart=/bin/sleep 1", true},
|
||||
{"not KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=mixed", true},
|
||||
{"KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
{"invalid unit file", "[Service\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, change := ensureSystemdKillMode(strings.NewReader(tc.unitFile)); tc.wantChange != change {
|
||||
t.Errorf("ensureSystemdKillMode(%q) = %v, want %v", tc.unitFile, change, tc.wantChange)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,20 @@ package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
@@ -28,6 +39,67 @@ func hasElevatedPrivilege() (bool, error) {
|
||||
return token.IsMember(sid)
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions checks if the given service
|
||||
// has the correct failure actions configured, and updates them if not.
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil // no-op on non-Windows
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// 1. Retrieve the current config
|
||||
cfg, err := s.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Update the Description
|
||||
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
|
||||
|
||||
// 3. Apply the updated config
|
||||
if err := s.UpdateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then proceed with existing actions, e.g. setting failure actions
|
||||
actions := []mgr.RecoveryAction{
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
}
|
||||
|
||||
// Set the recovery actions (3 restarts, reset period = 120).
|
||||
err = s.SetRecoveryActions(actions, 120)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that failure actions are NOT triggered on user-initiated stops.
|
||||
var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG
|
||||
failureActionsFlag.FailureActionsOnNonCrashFailures = 0
|
||||
|
||||
if err := windows.ChangeServiceConfig2(
|
||||
s.Handle,
|
||||
windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
|
||||
(*byte)(unsafe.Pointer(&failureActionsFlag)),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openLogFile(path string, mode int) (*os.File, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||
@@ -79,3 +151,77 @@ func openLogFile(path string, mode int) (*os.File, error) {
|
||||
|
||||
return os.NewFile(uintptr(handle), path), nil
|
||||
}
|
||||
|
||||
const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{}))
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool {
|
||||
h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
p := windows.ProcessEntry32{Size: processEntrySize}
|
||||
for {
|
||||
e := windows.Process32Next(h, &p)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) {
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("Win32_ComputerSystem")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("WMI query failed")
|
||||
return false, 0
|
||||
}
|
||||
if instances == nil {
|
||||
mainLog.Load().Debug().Msg("WMI query returned nil instances")
|
||||
return false, 0
|
||||
}
|
||||
defer instances.Close()
|
||||
|
||||
if len(instances) == 0 {
|
||||
mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
val, err := instances[0].GetProperty("DomainRole")
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property")
|
||||
return false, 0
|
||||
}
|
||||
if val == nil {
|
||||
mainLog.Load().Debug().Msg("DomainRole property is nil")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Safely handle varied types: string or integer
|
||||
var roleInt int
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
// "4", "5", etc.
|
||||
parsed, parseErr := strconv.Atoi(v)
|
||||
if parseErr != nil {
|
||||
mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v)
|
||||
return false, 0
|
||||
}
|
||||
roleInt = parsed
|
||||
case int8, int16, int32, int64:
|
||||
roleInt = int(reflect.ValueOf(v).Int())
|
||||
case uint8, uint16, uint32, uint64:
|
||||
roleInt = int(reflect.ValueOf(v).Uint())
|
||||
default:
|
||||
mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v)
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Check if role indicates a domain controller
|
||||
isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController
|
||||
return isDC, roleInt
|
||||
}
|
||||
|
||||
25
cmd/cli/service_windows_test.go
Normal file
25
cmd/cli/service_windows_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_hasLocalDnsServerRunning(t *testing.T) {
|
||||
start := time.Now()
|
||||
hasDns := hasLocalDnsServerRunning()
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
hasDnsPowershell := hasLocalDnsServerRunningPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if hasDns != hasDnsPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns)
|
||||
}
|
||||
}
|
||||
|
||||
func hasLocalDnsServerRunningPowershell() bool {
|
||||
_, err := powershell("Get-Process -Name DNS")
|
||||
return err == nil
|
||||
}
|
||||
@@ -1,18 +1,15 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 100
|
||||
maxFailureRequest = 50
|
||||
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
|
||||
checkUpstreamBackoffSleep = 2 * time.Second
|
||||
)
|
||||
@@ -21,18 +18,24 @@ const (
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
checking map[string]bool
|
||||
down map[string]bool
|
||||
failureReq map[string]uint64
|
||||
recovered map[string]bool
|
||||
|
||||
// failureTimerActive tracks if a timer is already running for a given upstream.
|
||||
failureTimerActive map[string]bool
|
||||
}
|
||||
|
||||
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
recovered: make(map[string]bool),
|
||||
failureTimerActive: make(map[string]bool),
|
||||
}
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
@@ -42,14 +45,47 @@ func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increase failed queries count for an upstream by 1.
|
||||
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
|
||||
// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down
|
||||
// within 10 seconds if failures persist, without spawning duplicate goroutines.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
if um.recovered[upstream] {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream)
|
||||
return
|
||||
}
|
||||
|
||||
um.failureReq[upstream] += 1
|
||||
failedCount := um.failureReq[upstream]
|
||||
um.down[upstream] = failedCount >= maxFailureRequest
|
||||
|
||||
// Log the updated failure count.
|
||||
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// If this is the first failure and no timer is running, start a 10-second timer.
|
||||
if failedCount == 1 && !um.failureTimerActive[upstream] {
|
||||
um.failureTimerActive[upstream] = true
|
||||
go func(upstream string) {
|
||||
time.Sleep(10 * time.Second)
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
// If no success occurred during the 10-second window (i.e. counter remains > 0)
|
||||
// and the upstream is not in a recovered state, mark it as down.
|
||||
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
|
||||
}
|
||||
// Reset the timer flag so that a new timer can be spawned if needed.
|
||||
um.failureTimerActive[upstream] = false
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
@@ -63,56 +99,28 @@ func (um *upstreamMonitor) isDown(upstream string) bool {
|
||||
// reset marks an upstream as up and set failed queries counter to zero.
|
||||
func (um *upstreamMonitor) reset(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
um.failureReq[upstream] = 0
|
||||
um.down[upstream] = false
|
||||
}
|
||||
|
||||
// checkUpstream checks the given upstream status, periodically sending query to upstream
|
||||
// until successfully. An upstream status/counter will be reset once it becomes reachable.
|
||||
func (p *prog) checkUpstream(upstream string, uc *ctrld.UpstreamConfig) {
|
||||
p.um.mu.Lock()
|
||||
isChecking := p.um.checking[upstream]
|
||||
if isChecking {
|
||||
p.um.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.um.checking[upstream] = true
|
||||
p.um.mu.Unlock()
|
||||
defer func() {
|
||||
p.um.mu.Lock()
|
||||
p.um.checking[upstream] = false
|
||||
p.um.mu.Unlock()
|
||||
um.recovered[upstream] = true
|
||||
um.mu.Unlock()
|
||||
go func() {
|
||||
// debounce the recovery to avoid incrementing failure counts already in flight
|
||||
time.Sleep(1 * time.Second)
|
||||
um.mu.Lock()
|
||||
um.recovered[upstream] = false
|
||||
um.mu.Unlock()
|
||||
}()
|
||||
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not check upstream")
|
||||
return
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
|
||||
check := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
uc.ReBootstrap()
|
||||
_, err := resolver.Resolve(ctx, msg)
|
||||
return err
|
||||
}
|
||||
for {
|
||||
if err := check(); err == nil {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is online", uc.Endpoint)
|
||||
p.um.reset(upstream)
|
||||
if p.leakingQuery.CompareAndSwap(true, false) {
|
||||
p.leakingQueryMu.Lock()
|
||||
p.leakingQueryWasRun = false
|
||||
p.leakingQueryMu.Unlock()
|
||||
mainLog.Load().Warn().Msg("stop leaking query")
|
||||
}
|
||||
return
|
||||
}
|
||||
time.Sleep(checkUpstreamBackoffSleep)
|
||||
}
|
||||
}
|
||||
|
||||
// countHealthy returns the number of upstreams in the provided map that are considered healthy.
|
||||
func (um *upstreamMonitor) countHealthy(upstreams []string) int {
|
||||
var count int
|
||||
um.mu.RLock()
|
||||
for _, upstream := range upstreams {
|
||||
if !um.down[upstream] {
|
||||
count++
|
||||
}
|
||||
}
|
||||
um.mu.RUnlock()
|
||||
return count
|
||||
}
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
package main
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cli.Main()
|
||||
// make sure we exit with 0 if there are no errors
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -28,15 +28,17 @@ type AppCallback interface {
|
||||
// Start configures utility with config.toml from provided directory.
|
||||
// This function will block until Stop is called
|
||||
// Check port availability prior to calling it.
|
||||
func (c *Controller) Start(CdUID string, HomeDir string, UpstreamProto string, logLevel int, logPath string) {
|
||||
func (c *Controller) Start(CdUID string, ProvisionID string, CustomHostname string, HomeDir string, UpstreamProto string, logLevel int, logPath string) {
|
||||
if c.stopCh == nil {
|
||||
c.stopCh = make(chan struct{})
|
||||
c.Config = cli.AppConfig{
|
||||
CdUID: CdUID,
|
||||
HomeDir: HomeDir,
|
||||
UpstreamProto: UpstreamProto,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
CdUID: CdUID,
|
||||
ProvisionID: ProvisionID,
|
||||
CustomHostname: CustomHostname,
|
||||
HomeDir: HomeDir,
|
||||
UpstreamProto: UpstreamProto,
|
||||
Verbose: logLevel,
|
||||
LogPath: logPath,
|
||||
}
|
||||
appCallback := mapCallback(c.AppCallback)
|
||||
cli.RunMobile(&c.Config, &appCallback, c.stopCh)
|
||||
|
||||
133
config.go
133
config.go
@@ -53,10 +53,27 @@ const (
|
||||
FreeDnsDomain = "freedns.controld.com"
|
||||
// FreeDNSBoostrapIP is the IP address of freedns.controld.com.
|
||||
FreeDNSBoostrapIP = "76.76.2.11"
|
||||
// FreeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com.
|
||||
FreeDNSBoostrapIPv6 = "2606:1a40::11"
|
||||
// PremiumDnsDomain is the domain name of premium ControlD service.
|
||||
PremiumDnsDomain = "dns.controld.com"
|
||||
// PremiumDNSBoostrapIP is the IP address of dns.controld.com.
|
||||
PremiumDNSBoostrapIP = "76.76.2.22"
|
||||
// PremiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.com.
|
||||
PremiumDNSBoostrapIPv6 = "2606:1a40::22"
|
||||
|
||||
// freeDnsDomainDev is the domain name of free ControlD service on dev env.
|
||||
freeDnsDomainDev = "freedns.controld.dev"
|
||||
// freeDNSBoostrapIP is the IP address of freedns.controld.dev.
|
||||
freeDNSBoostrapIP = "176.125.239.11"
|
||||
// freeDNSBoostrapIPv6 is the IPv6 address of freedns.controld.com.
|
||||
freeDNSBoostrapIPv6 = "2606:1a40:f000::11"
|
||||
// premiumDnsDomainDev is the domain name of premium ControlD service on dev env.
|
||||
premiumDnsDomainDev = "dns.controld.dev"
|
||||
// premiumDNSBoostrapIP is the IP address of dns.controld.dev.
|
||||
premiumDNSBoostrapIP = "176.125.239.22"
|
||||
// premiumDNSBoostrapIPv6 is the IPv6 address of dns.controld.dev.
|
||||
premiumDNSBoostrapIPv6 = "2606:1a40:f000::22"
|
||||
|
||||
controlDComDomain = "controld.com"
|
||||
controlDNetDomain = "controld.net"
|
||||
@@ -205,7 +222,7 @@ type ServiceConfig struct {
|
||||
CacheFlushDomains []string `mapstructure:"cache_flush_domains" toml:"cache_flush_domains" validate:"max=256"`
|
||||
MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"`
|
||||
DHCPLeaseFile string `mapstructure:"dhcp_lease_file_path" toml:"dhcp_lease_file_path" validate:"omitempty,file"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp"`
|
||||
DHCPLeaseFileFormat string `mapstructure:"dhcp_lease_file_format" toml:"dhcp_lease_file_format" validate:"required_unless=DHCPLeaseFile '',omitempty,oneof=dnsmasq isc-dhcp kea-dhcp4"`
|
||||
DiscoverMDNS *bool `mapstructure:"discover_mdns" toml:"discover_mdns,omitempty"`
|
||||
DiscoverARP *bool `mapstructure:"discover_arp" toml:"discover_arp,omitempty"`
|
||||
DiscoverDHCP *bool `mapstructure:"discover_dhcp" toml:"discover_dhcp,omitempty"`
|
||||
@@ -261,6 +278,7 @@ type UpstreamConfig struct {
|
||||
http3RoundTripper6 http.RoundTripper
|
||||
certPool *x509.CertPool
|
||||
u *url.URL
|
||||
fallbackOnce sync.Once
|
||||
uid string
|
||||
}
|
||||
|
||||
@@ -340,6 +358,15 @@ func (uc *UpstreamConfig) Init() {
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyMsg creates and returns a new DNS message could be used for testing upstream health.
|
||||
func (uc *UpstreamConfig) VerifyMsg() *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.RecursionDesired = true
|
||||
msg.SetQuestion(".", dns.TypeNS)
|
||||
msg.SetEdns0(4096, false) // ensure handling of large DNS response
|
||||
return msg
|
||||
}
|
||||
|
||||
// VerifyDomain returns the domain name that could be resolved by the upstream endpoint.
|
||||
// It returns empty for non-ControlD upstream endpoint.
|
||||
func (uc *UpstreamConfig) VerifyDomain() string {
|
||||
@@ -384,7 +411,7 @@ func (uc *UpstreamConfig) IsDiscoverable() bool {
|
||||
return *uc.Discoverable
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate:
|
||||
case ResolverTypeOS, ResolverTypeLegacy, ResolverTypePrivate, ResolverTypeLocal:
|
||||
if ip, err := netip.ParseAddr(uc.Domain); err == nil {
|
||||
return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || tsaddr.CGNATRange().Contains(ip)
|
||||
}
|
||||
@@ -402,12 +429,6 @@ func (uc *UpstreamConfig) SetCertPool(cp *x509.CertPool) {
|
||||
uc.certPool = cp
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
uc.setupBootstrapIP(true)
|
||||
}
|
||||
|
||||
// UID returns the unique identifier of the upstream.
|
||||
func (uc *UpstreamConfig) UID() string {
|
||||
return uc.uid
|
||||
@@ -415,11 +436,19 @@ func (uc *UpstreamConfig) UID() string {
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
// The upstream domain will be looked up using following orders:
|
||||
//
|
||||
// - Current system DNS settings.
|
||||
// - Direct IPs table for ControlD upstreams.
|
||||
// - ControlD Bootstrap DNS 76.76.2.22
|
||||
//
|
||||
// The setup process will block until there's usable IPs found.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 10*time.Second)
|
||||
isControlD := uc.IsControlD()
|
||||
nss := initDefaultOsResolver()
|
||||
for {
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, nss)
|
||||
// For ControlD upstream, the bootstrap IPs could not be RFC 1918 addresses,
|
||||
// filtering them out here to prevent weird behavior.
|
||||
if isControlD {
|
||||
@@ -432,6 +461,15 @@ func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
}
|
||||
}
|
||||
uc.bootstrapIPs = uc.bootstrapIPs[:n]
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
uc.bootstrapIPs = bootstrapIPsFromControlDDomain(uc.Domain)
|
||||
ProxyLogger.Load().Warn().Msgf("no record found for %q, lookup from direct IP table", uc.Domain)
|
||||
}
|
||||
}
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
ProxyLogger.Load().Warn().Msgf("no record found for %q, using bootstrap server: %s", uc.Domain, PremiumDNSBoostrapIP)
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, []string{net.JoinHostPort(PremiumDNSBoostrapIP, "53")})
|
||||
|
||||
}
|
||||
if len(uc.bootstrapIPs) > 0 {
|
||||
break
|
||||
@@ -458,7 +496,7 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
}
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
if uc.rebootstrap.CompareAndSwap(false, true) {
|
||||
ProxyLogger.Load().Debug().Msg("re-bootstrapping upstream ip")
|
||||
ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc)
|
||||
}
|
||||
return true, nil
|
||||
})
|
||||
@@ -485,7 +523,7 @@ func (uc *UpstreamConfig) setupDOHTransport() {
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
if hasIPv6() {
|
||||
if HasIPv6() {
|
||||
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.transport6 = uc.transport4
|
||||
@@ -529,7 +567,7 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs)
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs, ProxyLogger.Load())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -544,7 +582,10 @@ func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
|
||||
// Ping warms up the connection to DoH/DoH3 upstream.
|
||||
func (uc *UpstreamConfig) Ping() {
|
||||
_ = uc.ping()
|
||||
if err := uc.ping(); err != nil {
|
||||
ProxyLogger.Load().Debug().Err(err).Msgf("upstream ping failed: %s", uc.Endpoint)
|
||||
_ = uc.FallbackToDirectIP()
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorPing is like Ping, but return an error if any.
|
||||
@@ -581,7 +622,6 @@ func (uc *UpstreamConfig) ping() error {
|
||||
for _, typ := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH:
|
||||
|
||||
if err := ping(uc.dohTransport(typ)); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -655,7 +695,7 @@ func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
case dns.TypeA:
|
||||
return pick(uc.bootstrapIPs4)
|
||||
default:
|
||||
if hasIPv6() {
|
||||
if HasIPv6() {
|
||||
return pick(uc.bootstrapIPs6)
|
||||
}
|
||||
return pick(uc.bootstrapIPs4)
|
||||
@@ -677,7 +717,7 @@ func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
case dns.TypeA:
|
||||
return "tcp4-tls", "udp4"
|
||||
default:
|
||||
if hasIPv6() {
|
||||
if HasIPv6() {
|
||||
return "tcp6-tls", "udp6"
|
||||
}
|
||||
return "tcp4-tls", "udp4"
|
||||
@@ -749,6 +789,41 @@ func (uc *UpstreamConfig) initDnsStamps() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Context returns a new context with timeout set from upstream config.
|
||||
func (uc *UpstreamConfig) Context(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if uc.Timeout > 0 {
|
||||
return context.WithTimeout(ctx, time.Millisecond*time.Duration(uc.Timeout))
|
||||
}
|
||||
return context.WithCancel(ctx)
|
||||
}
|
||||
|
||||
// FallbackToDirectIP changes ControlD upstream endpoint to use direct IP instead of domain.
|
||||
func (uc *UpstreamConfig) FallbackToDirectIP() bool {
|
||||
if !uc.IsControlD() {
|
||||
return false
|
||||
}
|
||||
if uc.u == nil || uc.Domain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
done := false
|
||||
uc.fallbackOnce.Do(func() {
|
||||
var ip string
|
||||
switch {
|
||||
case dns.IsSubDomain(PremiumDnsDomain, uc.Domain):
|
||||
ip = PremiumDNSBoostrapIP
|
||||
case dns.IsSubDomain(FreeDnsDomain, uc.Domain):
|
||||
ip = FreeDNSBoostrapIP
|
||||
default:
|
||||
return
|
||||
}
|
||||
ProxyLogger.Load().Warn().Msgf("using direct IP for %q: %s", uc.Endpoint, ip)
|
||||
uc.u.Host = ip
|
||||
done = true
|
||||
})
|
||||
return done
|
||||
}
|
||||
|
||||
// Init initialized necessary values for an ListenerConfig.
|
||||
func (lc *ListenerConfig) Init() {
|
||||
if lc.Policy != nil {
|
||||
@@ -886,3 +961,27 @@ func upstreamUID() string {
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the UpstreamConfig for logging.
|
||||
func (uc *UpstreamConfig) String() string {
|
||||
if uc == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
return fmt.Sprintf("{name: %q, type: %q, endpoint: %q, bootstrap_ip: %q, domain: %q, ip_stack: %q}",
|
||||
uc.Name, uc.Type, uc.Endpoint, uc.BootstrapIP, uc.Domain, uc.IPStack)
|
||||
}
|
||||
|
||||
// bootstrapIPsFromControlDDomain returns bootstrap IPs for ControlD domain.
|
||||
func bootstrapIPsFromControlDDomain(domain string) []string {
|
||||
switch {
|
||||
case dns.IsSubDomain(PremiumDnsDomain, domain):
|
||||
return []string{PremiumDNSBoostrapIP, PremiumDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(FreeDnsDomain, domain):
|
||||
return []string{FreeDNSBoostrapIP, FreeDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(premiumDnsDomainDev, domain):
|
||||
return []string{premiumDNSBoostrapIP, premiumDNSBoostrapIPv6}
|
||||
case dns.IsSubDomain(freeDnsDomainDev, domain):
|
||||
return []string{freeDNSBoostrapIP, freeDNSBoostrapIPv6}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,19 +8,43 @@ import (
|
||||
)
|
||||
|
||||
func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
|
||||
uc := &UpstreamConfig{
|
||||
Name: "test",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p2",
|
||||
Timeout: 5000,
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
}{
|
||||
{
|
||||
name: "doh/doh3",
|
||||
uc: &UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p2",
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "doq/dot",
|
||||
uc: &UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: ResolverTypeDOT,
|
||||
Endpoint: "p2.freedns.controld.com",
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
}
|
||||
uc.Init()
|
||||
uc.setupBootstrapIP(false)
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
t.Log(defaultNameservers())
|
||||
t.Fatal("could not bootstrap ip without bootstrap DNS")
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Enable parallel tests once https://github.com/microsoft/wmi/issues/165 fixed.
|
||||
// t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.SetupBootstrapIP()
|
||||
if len(tc.uc.bootstrapIPs) == 0 {
|
||||
t.Log(defaultNameservers())
|
||||
t.Fatalf("could not bootstrap ip: %s", tc.uc.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Log(uc)
|
||||
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_Init(t *testing.T) {
|
||||
|
||||
@@ -24,7 +24,7 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
if hasIPv6() {
|
||||
if HasIPv6() {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
@@ -34,9 +34,9 @@ func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
rt := &http3.RoundTripper{}
|
||||
rt := &http3.Transport{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
if uc.BootstrapIP != "" {
|
||||
@@ -64,7 +64,7 @@ func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
ProxyLogger.Load().Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
return conn, err
|
||||
}
|
||||
runtime.SetFinalizer(rt, func(rt *http3.RoundTripper) {
|
||||
runtime.SetFinalizer(rt, func(rt *http3.Transport) {
|
||||
rt.CloseIdleConnections()
|
||||
})
|
||||
return rt
|
||||
@@ -96,14 +96,14 @@ func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
|
||||
// - quic dialer is different with net.Dialer
|
||||
// - simplification for quic free version
|
||||
type parallelDialerResult struct {
|
||||
conn quic.EarlyConnection
|
||||
conn *quic.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
type quicParallelDialer struct{}
|
||||
|
||||
// Dial performs parallel dialing to the given address list.
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
|
||||
@@ -111,6 +111,7 @@ func TestConfigValidation(t *testing.T) {
|
||||
{"doh3 endpoint without type", doh3UpstreamEndpointWithoutType(t), false},
|
||||
{"sdns endpoint without type", sdnsUpstreamEndpointWithoutType(t), false},
|
||||
{"maximum number of flush cache domains", configWithInvalidFlushCacheDomain(t), true},
|
||||
{"kea dhcp4 format", configWithDhcp4KeaFormat(t), false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
@@ -307,6 +308,12 @@ func configWithInvalidLeaseFileFormat(t *testing.T) *ctrld.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithDhcp4KeaFormat(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Service.DHCPLeaseFileFormat = "kea-dhcp4"
|
||||
return cfg
|
||||
}
|
||||
|
||||
func configWithInvalidDoHEndpoint(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Upstream["0"].Endpoint = "/1.1.1.1"
|
||||
|
||||
7
desktop_darwin.go
Normal file
7
desktop_darwin.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return true
|
||||
}
|
||||
9
desktop_others.go
Normal file
9
desktop_others.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return false
|
||||
}
|
||||
7
desktop_windows.go
Normal file
7
desktop_windows.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package ctrld
|
||||
|
||||
// IsDesktopPlatform indicates if ctrld is running on a desktop platform,
|
||||
// currently defined as macOS or Windows workstation.
|
||||
func IsDesktopPlatform() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
30
dns.go
Normal file
30
dns.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// SetCacheReply extracts and stores the necessary data from the message for a cached answer.
|
||||
func SetCacheReply(answer, msg *dns.Msg, code int) {
|
||||
answer.SetRcode(msg, code)
|
||||
cCookie := getEdns0Cookie(msg.IsEdns0())
|
||||
sCookie := getEdns0Cookie(answer.IsEdns0())
|
||||
if cCookie != nil && sCookie != nil {
|
||||
// Client cookie is fixed size 8 bytes, Server cookie is variable size 8 -> 32 bytes.
|
||||
// See https://datatracker.ietf.org/doc/html/rfc7873#section-4
|
||||
sCookie.Cookie = cCookie.Cookie[:16] + sCookie.Cookie[16:]
|
||||
}
|
||||
}
|
||||
|
||||
// getEdns0Cookie returns Edns0 cookie from *dns.OPT if present.
|
||||
func getEdns0Cookie(opt *dns.OPT) *dns.EDNS0_COOKIE {
|
||||
if opt == nil {
|
||||
return nil
|
||||
}
|
||||
for _, o := range opt.Option {
|
||||
if e, ok := o.(*dns.EDNS0_COOKIE); ok {
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -166,7 +166,6 @@ before serving the query.
|
||||
|
||||
### max_concurrent_requests
|
||||
The number of concurrent requests that will be handled, must be a non-negative integer.
|
||||
Tweaking this value depends on the capacity of your system.
|
||||
|
||||
- Type: number
|
||||
- Required: no
|
||||
@@ -179,6 +178,8 @@ Perform LAN client discovery using mDNS. This will spawn a listener on port 5353
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_arp
|
||||
Perform LAN client discovery using ARP.
|
||||
|
||||
@@ -186,6 +187,8 @@ Perform LAN client discovery using ARP.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_dhcp
|
||||
Perform LAN client discovery using DHCP leases files. Common file locations are auto-discovered.
|
||||
|
||||
@@ -193,6 +196,8 @@ Perform LAN client discovery using DHCP leases files. Common file locations are
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_ptr
|
||||
Perform LAN client discovery using PTR queries.
|
||||
|
||||
@@ -200,6 +205,8 @@ Perform LAN client discovery using PTR queries.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_hosts
|
||||
Perform LAN client discovery using hosts file.
|
||||
|
||||
@@ -207,6 +214,8 @@ Perform LAN client discovery using hosts file.
|
||||
- Required: no
|
||||
- Default: true
|
||||
|
||||
This config is ignored, and always set to `false` on Windows Desktop and Macos.
|
||||
|
||||
### discover_refresh_interval
|
||||
Time in seconds between each discovery refresh loop to update new client information data.
|
||||
The default value is 120 seconds, lower this value to make the discovery process run more aggressively.
|
||||
@@ -253,9 +262,7 @@ Specifying the `ip` and `port` of the Prometheus metrics server. The Prometheus
|
||||
- Default: ""
|
||||
|
||||
### dns_watchdog_enabled
|
||||
Checking DNS changes to network interfaces and reverting to ctrld's own settings.
|
||||
|
||||
The DNS watchdog process only runs on Windows and MacOS.
|
||||
Watches all physical interfaces for DNS changes and reverts them to ctrld's settings.The DNS watchdog process only runs on Windows and MacOS.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
@@ -274,7 +281,7 @@ If the time duration is non-positive, default value will be used.
|
||||
- Default: 20s
|
||||
|
||||
### refetch_time
|
||||
Time in seconds between each iteration that reloads custom config if changed.
|
||||
Time in seconds between each iteration that reloads custom config from the API.
|
||||
|
||||
The value must be a positive number, any invalid value will be ignored and default value will be used.
|
||||
- Type: number
|
||||
@@ -282,7 +289,7 @@ The value must be a positive number, any invalid value will be ignored and defau
|
||||
- Default: 3600
|
||||
|
||||
### leak_on_upstream_failure
|
||||
Once ctrld is "offline", mean ctrld could not connect to any upstream, next queries will be leaked to OS resolver.
|
||||
If a remote upstream fails to resolve a query or is unreachable, `ctrld` will forward the queries to the default DNS resolver on the network. If failures persist, `ctrld` will remove itself from all networking interfaces until connectivity is restored.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
@@ -531,6 +538,15 @@ rules = [
|
||||
]
|
||||
```
|
||||
|
||||
If there is no explicitly defined rules, LAN queries will be handled solely by the OS resolver.
|
||||
|
||||
These following domains are considered LAN queries:
|
||||
|
||||
- Queries does not have dot `.` in domain name, like `machine1`, `example`, ... (1)
|
||||
- Queries have domain ends with: `.domain`, `.lan`, `.local`. (2)
|
||||
- All `SRV` queries of LAN hostname (1) + (2).
|
||||
- `PTR` queries with private IPs.
|
||||
|
||||
---
|
||||
|
||||
Note that the order of matching preference:
|
||||
|
||||
42
docs/known-issues.md
Normal file
42
docs/known-issues.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Known Issues
|
||||
|
||||
This document outlines known issues with ctrld and their current status, workarounds, and recommendations.
|
||||
|
||||
## macOS (Darwin) Issues
|
||||
|
||||
### Self-Upgrade Issue on Darwin 15.5
|
||||
|
||||
**Issue**: ctrld self-upgrading functionality may not work on macOS Darwin 15.5.
|
||||
|
||||
**Status**: Under investigation
|
||||
|
||||
**Description**: Users on macOS Darwin 15.5 may experience issues when ctrld attempts to perform automatic self-upgrades. The upgrade process would be triggered, but ctrld won't be upgraded.
|
||||
|
||||
**Workarounds**:
|
||||
1. **Recommended**: Upgrade your macOS system to Darwin 15.6 or later, which has been tested and verified to work correctly with ctrld self-upgrade functionality.
|
||||
2. **Alternative**: Run `ctrld upgrade prod` directly to manually upgrade ctrld to the latest version on Darwin 15.5.
|
||||
|
||||
**Affected Versions**: ctrld v1.4.2 and later on macOS Darwin 15.5
|
||||
|
||||
**Last Updated**: 05/09/2025
|
||||
|
||||
---
|
||||
|
||||
## Contributing to Known Issues
|
||||
|
||||
If you encounter an issue not listed here, please:
|
||||
|
||||
1. Check the [GitHub Issues](https://github.com/Control-D-Inc/ctrld/issues) to see if it's already reported
|
||||
2. If not reported, create a new issue with:
|
||||
- Detailed description of the problem
|
||||
- Steps to reproduce
|
||||
- Expected vs actual behavior
|
||||
- System information (OS, version, architecture)
|
||||
- ctrld version
|
||||
|
||||
## Issue Status Legend
|
||||
|
||||
- **Under investigation**: Issue is confirmed and being analyzed
|
||||
- **Workaround available**: Temporary solution exists while permanent fix is developed
|
||||
- **Fixed**: Issue has been resolved in a specific version
|
||||
- **Won't fix**: Issue is acknowledged but will not be addressed due to technical limitations or design decisions
|
||||
57
doh.go
57
doh.go
@@ -2,6 +2,7 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -113,7 +114,14 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
c.Transport = transport
|
||||
}
|
||||
resp, err := c.Do(req)
|
||||
if err != nil && r.uc.FallbackToDirectIP() {
|
||||
retryCtx, cancel := r.uc.Context(context.WithoutCancel(ctx))
|
||||
defer cancel()
|
||||
Log(ctx, ProxyLogger.Load().Warn().Err(err), "retrying request after fallback to direct ip")
|
||||
resp, err = c.Do(req.Clone(retryCtx))
|
||||
}
|
||||
if err != nil {
|
||||
err = wrapUrlError(err)
|
||||
if r.isDoH3 {
|
||||
if closer, ok := c.Transport.(io.Closer); ok {
|
||||
closer.Close()
|
||||
@@ -202,3 +210,52 @@ func newNextDNSHeaders(ci *ClientInfo) http.Header {
|
||||
}
|
||||
return header
|
||||
}
|
||||
|
||||
// wrapCertificateVerificationError wraps a certificate verification error with additional context about the certificate issuer.
|
||||
// It extracts information like the issuer, organization, and subject from the certificate for a more descriptive error output.
|
||||
// If no certificate-related information is available, it simply returns the original error unmodified.
|
||||
func wrapCertificateVerificationError(err error) error {
|
||||
var tlsErr *tls.CertificateVerificationError
|
||||
if errors.As(err, &tlsErr) {
|
||||
if len(tlsErr.UnverifiedCertificates) > 0 {
|
||||
cert := tlsErr.UnverifiedCertificates[0]
|
||||
// Extract a more user-friendly issuer name
|
||||
var issuer string
|
||||
var organization string
|
||||
if len(cert.Issuer.Organization) > 0 {
|
||||
organization = cert.Issuer.Organization[0]
|
||||
issuer = organization
|
||||
} else if cert.Issuer.CommonName != "" {
|
||||
issuer = cert.Issuer.CommonName
|
||||
} else {
|
||||
issuer = cert.Issuer.String()
|
||||
}
|
||||
|
||||
// Get the organization from the subject field as well
|
||||
if len(cert.Subject.Organization) > 0 {
|
||||
organization = cert.Subject.Organization[0]
|
||||
}
|
||||
|
||||
// Extract the subject information
|
||||
subjectCN := cert.Subject.CommonName
|
||||
if subjectCN == "" && len(cert.Subject.Organization) > 0 {
|
||||
subjectCN = cert.Subject.Organization[0]
|
||||
}
|
||||
return fmt.Errorf("%w: %s, %s, %s", tlsErr, subjectCN, organization, issuer)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// wrapUrlError inspects and wraps a URL error, focusing on certificate verification errors for detailed context.
|
||||
func wrapUrlError(err error) error {
|
||||
var urlErr *url.Error
|
||||
if errors.As(err, &urlErr) {
|
||||
var tlsErr *tls.CertificateVerificationError
|
||||
if errors.As(urlErr.Err, &tlsErr) {
|
||||
urlErr.Err = wrapCertificateVerificationError(tlsErr)
|
||||
return urlErr
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
243
doh_test.go
243
doh_test.go
@@ -1,8 +1,22 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
func Test_dohOsHeaderValue(t *testing.T) {
|
||||
@@ -21,3 +35,232 @@ func Test_dohOsHeaderValue(t *testing.T) {
|
||||
t.Fatalf("missing decoding value for: %q", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_wrapUrlError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "No wrapping for non-URL errors",
|
||||
err: errors.New("plain error"),
|
||||
wantErr: "plain error",
|
||||
},
|
||||
{
|
||||
name: "URL error without TLS error",
|
||||
err: &url.Error{
|
||||
Op: "Get",
|
||||
URL: "https://example.com",
|
||||
Err: errors.New("underlying error"),
|
||||
},
|
||||
wantErr: "Get \"https://example.com\": underlying error",
|
||||
},
|
||||
{
|
||||
name: "TLS error with missing unverified certificate data",
|
||||
err: &url.Error{
|
||||
Op: "Get",
|
||||
URL: "https://example.com",
|
||||
Err: &tls.CertificateVerificationError{
|
||||
UnverifiedCertificates: nil,
|
||||
Err: &x509.UnknownAuthorityError{},
|
||||
},
|
||||
},
|
||||
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority`,
|
||||
},
|
||||
{
|
||||
name: "TLS error with valid certificate data",
|
||||
err: &url.Error{
|
||||
Op: "Get",
|
||||
URL: "https://example.com",
|
||||
Err: &tls.CertificateVerificationError{
|
||||
UnverifiedCertificates: []*x509.Certificate{
|
||||
{
|
||||
Subject: pkix.Name{
|
||||
CommonName: "BadSubjectCN",
|
||||
Organization: []string{"BadSubjectOrg"},
|
||||
},
|
||||
Issuer: pkix.Name{
|
||||
CommonName: "BadIssuerCN",
|
||||
Organization: []string{"BadIssuerOrg"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Err: &x509.UnknownAuthorityError{},
|
||||
},
|
||||
},
|
||||
wantErr: `Get "https://example.com": tls: failed to verify certificate: x509: certificate signed by unknown authority: BadSubjectCN, BadSubjectOrg, BadIssuerOrg`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotErr := wrapUrlError(tt.err)
|
||||
if gotErr.Error() != tt.wantErr {
|
||||
t.Errorf("wrapCertificateVerificationError() error = %v, want %v", gotErr, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ClientCertificateVerificationError(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/dns-message")
|
||||
})
|
||||
tlsServer, cert := testTLSServer(t, handler)
|
||||
tlsServerUrl, err := url.Parse(tlsServer.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
quicServer := newTestQUICServer(t)
|
||||
http3Server := newTestHTTP3Server(t, handler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
}{
|
||||
{
|
||||
"doh",
|
||||
&UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: tlsServer.URL,
|
||||
Timeout: 1000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"doh3",
|
||||
&UpstreamConfig{
|
||||
Name: "doh3",
|
||||
Type: ResolverTypeDOH3,
|
||||
Endpoint: http3Server.addr,
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"doq",
|
||||
&UpstreamConfig{
|
||||
Name: "doq",
|
||||
Type: ResolverTypeDOQ,
|
||||
Endpoint: quicServer.addr,
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
{
|
||||
"dot",
|
||||
&UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: ResolverTypeDOT,
|
||||
Endpoint: net.JoinHostPort(tlsServerUrl.Hostname(), tlsServerUrl.Port()),
|
||||
Timeout: 1000,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
tc.uc.SetupBootstrapIP()
|
||||
r, err := NewResolver(tc.uc)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("verify.controld.com.", dns.TypeA)
|
||||
msg.RecursionDesired = true
|
||||
_, err = r.Resolve(context.Background(), msg)
|
||||
// Verify the error contains the expected certificate information
|
||||
if err == nil {
|
||||
t.Fatal("expected certificate verification error, got nil")
|
||||
}
|
||||
|
||||
// You can check the error contains information about the test certificate
|
||||
if !strings.Contains(err.Error(), cert.Issuer.CommonName) {
|
||||
t.Fatalf("error should contain issuer information %q, got: %v", cert.Issuer.CommonName, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTLSServer creates an HTTPS test server with a self-signed certificate
|
||||
// returns the server and its certificate for verification testing
|
||||
// testTLSServer creates an HTTPS test server with a self-signed certificate
|
||||
func testTLSServer(t *testing.T, handler http.Handler) (*httptest.Server, *x509.Certificate) {
|
||||
t.Helper()
|
||||
|
||||
testCert := generateTestCertificate(t)
|
||||
|
||||
// Create a test server
|
||||
server := httptest.NewUnstartedServer(handler)
|
||||
server.TLS = &tls.Config{
|
||||
Certificates: []tls.Certificate{testCert.tlsCert},
|
||||
}
|
||||
server.StartTLS()
|
||||
|
||||
// Add cleanup
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
return server, testCert.cert
|
||||
}
|
||||
|
||||
// testHTTP3Server represents a structure for an HTTP/3 test server with its server instance, TLS certificate, and address.
|
||||
type testHTTP3Server struct {
|
||||
server *http3.Server
|
||||
cert *x509.Certificate
|
||||
addr string
|
||||
}
|
||||
|
||||
// newTestHTTP3Server creates and starts a test HTTP/3 server with a given handler and returns the server instance.
|
||||
func newTestHTTP3Server(t *testing.T, handler http.Handler) *testHTTP3Server {
|
||||
t.Helper()
|
||||
|
||||
testCert := generateTestCertificate(t)
|
||||
|
||||
// First create a listener to get the actual port
|
||||
udpAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
|
||||
udpConn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create UDP listener: %v", err)
|
||||
}
|
||||
|
||||
// Get the actual address
|
||||
actualAddr := udpConn.LocalAddr().String()
|
||||
|
||||
// Create TLS config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{testCert.tlsCert},
|
||||
NextProtos: []string{"h3"}, // HTTP/3 protocol identifier
|
||||
}
|
||||
|
||||
// Create HTTP/3 server
|
||||
server := &http3.Server{
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
|
||||
// Start the server with the existing UDP connection
|
||||
go func() {
|
||||
if err := server.Serve(udpConn); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
t.Logf("HTTP/3 server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
h3Server := &testHTTP3Server{
|
||||
server: server,
|
||||
cert: testCert.cert,
|
||||
addr: actualAddr,
|
||||
}
|
||||
|
||||
// Add cleanup
|
||||
t.Cleanup(func() {
|
||||
server.Close()
|
||||
udpConn.Close()
|
||||
})
|
||||
|
||||
// Wait a bit for the server to be ready
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
return h3Server
|
||||
}
|
||||
|
||||
2
doq.go
2
doq.go
@@ -43,7 +43,7 @@ func resolve(ctx context.Context, msg *dns.Msg, endpoint string, tlsConfig *tls.
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, wrapCertificateVerificationError(err)
|
||||
}
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
223
doq_test.go
Normal file
223
doq_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
// test_helpers.go
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go"
|
||||
)
|
||||
|
||||
// testCertificate represents a test certificate with its components
|
||||
type testCertificate struct {
|
||||
cert *x509.Certificate
|
||||
tlsCert tls.Certificate
|
||||
template *x509.Certificate
|
||||
}
|
||||
|
||||
// generateTestCertificate creates a self-signed certificate for testing
|
||||
func generateTestCertificate(t *testing.T) *testCertificate {
|
||||
t.Helper()
|
||||
|
||||
// Generate private key
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
// Create certificate template
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Test Org"},
|
||||
CommonName: "Test CA",
|
||||
},
|
||||
Issuer: pkix.Name{
|
||||
Organization: []string{"Test Issuer Org"},
|
||||
CommonName: "Test Issuer CA",
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
|
||||
// Create certificate
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create certificate: %v", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(derBytes)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse certificate: %v", err)
|
||||
}
|
||||
|
||||
// Create TLS certificate
|
||||
tlsCert := tls.Certificate{
|
||||
Certificate: [][]byte{derBytes},
|
||||
PrivateKey: privateKey,
|
||||
}
|
||||
|
||||
return &testCertificate{
|
||||
cert: cert,
|
||||
tlsCert: tlsCert,
|
||||
template: template,
|
||||
}
|
||||
}
|
||||
|
||||
// testQUICServer is a structure representing a test QUIC server for handling connections and streams.
|
||||
// listener is the QUIC listener used to accept incoming connections.
|
||||
// cert is the x509 certificate used by the server for authentication.
|
||||
// addr is the address on which the test server is running.
|
||||
type testQUICServer struct {
|
||||
listener *quic.Listener
|
||||
cert *x509.Certificate
|
||||
addr string
|
||||
}
|
||||
|
||||
// newTestQUICServer creates and initializes a test QUIC server with TLS configuration and starts accepting connections.
|
||||
func newTestQUICServer(t *testing.T) *testQUICServer {
|
||||
t.Helper()
|
||||
|
||||
testCert := generateTestCertificate(t)
|
||||
|
||||
// Create TLS config
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{testCert.tlsCert},
|
||||
NextProtos: []string{"doq"},
|
||||
}
|
||||
|
||||
// Create QUIC listener
|
||||
listener, err := quic.ListenAddr("127.0.0.1:0", tlsConfig, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create QUIC listener: %v", err)
|
||||
}
|
||||
|
||||
server := &testQUICServer{
|
||||
listener: listener,
|
||||
cert: testCert.cert,
|
||||
addr: listener.Addr().String(),
|
||||
}
|
||||
|
||||
// Start handling connections
|
||||
go server.serve(t)
|
||||
|
||||
// Add cleanup
|
||||
t.Cleanup(func() {
|
||||
listener.Close()
|
||||
})
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
// serve handles incoming connections on the QUIC listener and delegates them to connection handlers in separate goroutines.
|
||||
func (s *testQUICServer) serve(t *testing.T) {
|
||||
for {
|
||||
conn, err := s.listener.Accept(context.Background())
|
||||
if err != nil {
|
||||
// Check if the error is due to the listener being closed
|
||||
if strings.Contains(err.Error(), "server closed") {
|
||||
return
|
||||
}
|
||||
t.Logf("failed to accept connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleConnection(t, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection manages an individual QUIC connection by accepting and handling incoming streams in separate goroutines.
|
||||
func (s *testQUICServer) handleConnection(t *testing.T, conn *quic.Conn) {
|
||||
for {
|
||||
stream, err := conn.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go s.handleStream(t, stream)
|
||||
}
|
||||
}
|
||||
|
||||
// handleStream processes a single QUIC stream, reads DNS messages, generates a response, and sends it back to the client.
|
||||
func (s *testQUICServer) handleStream(t *testing.T, stream *quic.Stream) {
|
||||
defer stream.Close()
|
||||
|
||||
// Read length (2 bytes)
|
||||
lenBuf := make([]byte, 2)
|
||||
_, err := stream.Read(lenBuf)
|
||||
if err != nil {
|
||||
t.Logf("failed to read message length: %v", err)
|
||||
return
|
||||
}
|
||||
msgLen := uint16(lenBuf[0])<<8 | uint16(lenBuf[1])
|
||||
|
||||
// Read message
|
||||
msgBuf := make([]byte, msgLen)
|
||||
_, err = stream.Read(msgBuf)
|
||||
if err != nil {
|
||||
t.Logf("failed to read message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse DNS message
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(msgBuf); err != nil {
|
||||
t.Logf("failed to unpack DNS message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create response
|
||||
response := new(dns.Msg)
|
||||
response.SetReply(msg)
|
||||
response.Authoritative = true
|
||||
|
||||
// Add a test answer
|
||||
if len(msg.Question) > 0 && msg.Question[0].Qtype == dns.TypeA {
|
||||
response.Answer = append(response.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: msg.Question[0].Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: net.ParseIP("192.0.2.1"), // TEST-NET-1 address
|
||||
})
|
||||
}
|
||||
|
||||
// Pack response
|
||||
respBytes, err := response.Pack()
|
||||
if err != nil {
|
||||
t.Logf("failed to pack response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write length
|
||||
respLen := uint16(len(respBytes))
|
||||
_, err = stream.Write([]byte{byte(respLen >> 8), byte(respLen & 0xFF)})
|
||||
if err != nil {
|
||||
t.Logf("failed to write response length: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write response
|
||||
_, err = stream.Write(respBytes)
|
||||
if err != nil {
|
||||
t.Logf("failed to write response: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
5
dot.go
5
dot.go
@@ -18,12 +18,11 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
// dns.controld.dev first. By using a dialer with custom resolver,
|
||||
// we ensure that we can always resolve the bootstrap domain
|
||||
// regardless of the machine DNS status.
|
||||
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
|
||||
dialer := newDialer(net.JoinHostPort(controldPublicDns, "53"))
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
|
||||
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
|
||||
dnsClient := &dns.Client{
|
||||
Net: tcpNet,
|
||||
@@ -39,5 +38,5 @@ func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
}
|
||||
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint)
|
||||
return answer, err
|
||||
return answer, wrapCertificateVerificationError(err)
|
||||
}
|
||||
|
||||
34
go.mod
34
go.mod
@@ -1,14 +1,15 @@
|
||||
module github.com/Control-D-Inc/ctrld
|
||||
|
||||
go 1.23
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.23.1
|
||||
toolchain go1.23.7
|
||||
|
||||
require (
|
||||
github.com/Masterminds/semver v1.5.0
|
||||
github.com/Masterminds/semver/v3 v3.2.1
|
||||
github.com/ameshkov/dnsstamps v1.0.3
|
||||
github.com/coreos/go-systemd/v22 v22.5.0
|
||||
github.com/cuonglm/osinfo v0.0.0-20230921071424-e0e1b1e0bbbf
|
||||
github.com/docker/go-units v0.5.0
|
||||
github.com/frankban/quicktest v1.14.6
|
||||
github.com/fsnotify/fsnotify v1.7.0
|
||||
github.com/go-playground/validator/v10 v10.11.1
|
||||
@@ -20,6 +21,7 @@ require (
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
|
||||
github.com/kardianos/service v1.2.1
|
||||
github.com/mdlayher/ndp v1.0.1
|
||||
github.com/microsoft/wmi v0.24.5
|
||||
github.com/miekg/dns v1.1.58
|
||||
github.com/minio/selfupdate v0.6.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
@@ -27,16 +29,16 @@ require (
|
||||
github.com/prometheus/client_golang v1.19.1
|
||||
github.com/prometheus/client_model v0.5.0
|
||||
github.com/prometheus/prom2json v1.3.3
|
||||
github.com/quic-go/quic-go v0.42.0
|
||||
github.com/quic-go/quic-go v0.54.0
|
||||
github.com/rs/zerolog v1.28.0
|
||||
github.com/spf13/cobra v1.8.1
|
||||
github.com/spf13/pflag v1.0.5
|
||||
github.com/spf13/viper v1.16.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/vishvananda/netlink v1.2.1-beta.2
|
||||
golang.org/x/net v0.27.0
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/sys v0.22.0
|
||||
golang.org/x/net v0.38.0
|
||||
golang.org/x/sync v0.12.0
|
||||
golang.org/x/sys v0.31.0
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
tailscale.com v1.74.0
|
||||
)
|
||||
@@ -49,12 +51,12 @@ require (
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 // indirect
|
||||
github.com/go-ole/go-ole v1.3.0 // indirect
|
||||
github.com/go-playground/locales v0.14.0 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.0 // indirect
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
|
||||
github.com/golang/protobuf v1.5.4 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jsimonetti/rtnetlink v1.4.0 // indirect
|
||||
@@ -70,12 +72,12 @@ require (
|
||||
github.com/mdlayher/packet v1.1.2 // indirect
|
||||
github.com/mdlayher/socket v0.5.0 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/prometheus/common v0.48.0 // indirect
|
||||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
github.com/quic-go/qpack v0.4.0 // indirect
|
||||
github.com/quic-go/qpack v0.5.1 // indirect
|
||||
github.com/rivo/uniseg v0.4.4 // indirect
|
||||
github.com/rogpeppe/go-internal v1.11.0 // indirect
|
||||
github.com/spf13/afero v1.9.5 // indirect
|
||||
@@ -84,13 +86,13 @@ require (
|
||||
github.com/subosito/gotenv v1.4.2 // indirect
|
||||
github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect
|
||||
golang.org/x/crypto v0.25.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect
|
||||
golang.org/x/mod v0.19.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
golang.org/x/tools v0.23.0 // indirect
|
||||
google.golang.org/protobuf v1.33.0 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
@@ -99,4 +101,4 @@ require (
|
||||
|
||||
replace github.com/mr-karan/doggo => github.com/Windscribe/doggo v0.0.0-20220919152748-2c118fc391f8
|
||||
|
||||
replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be
|
||||
replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c
|
||||
|
||||
74
go.sum
74
go.sum
@@ -40,10 +40,10 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
|
||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww=
|
||||
github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y=
|
||||
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be h1:qBKVRi7Mom5heOkyZ+NCIu9HZBiNCsRqrRe5t9pooik=
|
||||
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w=
|
||||
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
|
||||
github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
|
||||
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c h1:UqFsxmwiCh/DBvwJB0m7KQ2QFDd6DdUkosznfMppdhE=
|
||||
github.com/Windscribe/zerolog v0.0.0-20241206130353-cc6e8ef5397c/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
|
||||
github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
|
||||
github.com/ameshkov/dnsstamps v1.0.3 h1:Srzik+J9mivH1alRACTbys2xOxs0lRH9qnTA7Y1OYVo=
|
||||
@@ -74,6 +74,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa h1:h8TfIT1xc8FWbwwpmHn1J5i43Y0uZP97GqasGCzSRJk=
|
||||
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
@@ -89,8 +91,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg=
|
||||
github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
|
||||
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
|
||||
github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A=
|
||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU=
|
||||
@@ -99,8 +101,6 @@ github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/j
|
||||
github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA=
|
||||
github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ=
|
||||
github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
|
||||
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg=
|
||||
github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU=
|
||||
@@ -158,10 +158,10 @@ github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hf
|
||||
github.com/google/pprof v0.0.0-20201023163331-3e6fc7fc9c4c/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo=
|
||||
github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
||||
@@ -207,11 +207,10 @@ github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w=
|
||||
github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
@@ -227,6 +226,8 @@ github.com/mdlayher/packet v1.1.2 h1:3Up1NG6LZrsgDVn6X4L9Ge/iyRyxFEFD9o6Pr3Q1nQY
|
||||
github.com/mdlayher/packet v1.1.2/go.mod h1:GEu1+n9sG5VtiRE4SydOmX5GTwyyYlteZiFU+x0kew4=
|
||||
github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI=
|
||||
github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI=
|
||||
github.com/microsoft/wmi v0.24.5 h1:NT+WqhjKbEcg3ldmDsRMarWgHGkpeW+gMopSCfON0kM=
|
||||
github.com/microsoft/wmi v0.24.5/go.mod h1:1zbdSF0A+5OwTUII5p3hN7/K6KF2m3o27pSG6Y51VU8=
|
||||
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
|
||||
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
|
||||
github.com/minio/selfupdate v0.6.0 h1:i76PgT0K5xO9+hjzKcacQtO7+MjJ4JKA8Ak8XQ9DDwU=
|
||||
@@ -235,16 +236,13 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
|
||||
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ=
|
||||
github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
@@ -261,10 +259,10 @@ github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k
|
||||
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
|
||||
github.com/prometheus/prom2json v1.3.3 h1:IYfSMiZ7sSOfliBoo89PcufjWO4eAR0gznGcETyaUgo=
|
||||
github.com/prometheus/prom2json v1.3.3/go.mod h1:Pv4yIPktEkK7btWsrUTWDDDrnpUrAELaOCj+oFwlgmc=
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
|
||||
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
|
||||
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg=
|
||||
github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis=
|
||||
github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
@@ -274,7 +272,7 @@ github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6po
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/afero v1.9.5 h1:stMpOSZFs//0Lv29HduCmli3GUfpFoF3Y1Q/aXj/wVM=
|
||||
github.com/spf13/afero v1.9.5/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ=
|
||||
@@ -322,8 +320,8 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8=
|
||||
go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g=
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
|
||||
@@ -338,8 +336,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm
|
||||
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
|
||||
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -350,8 +348,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
|
||||
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
|
||||
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
|
||||
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA=
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM=
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
|
||||
@@ -409,8 +407,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
|
||||
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
|
||||
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
|
||||
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@@ -430,8 +428,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
|
||||
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -472,16 +470,16 @@ golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
|
||||
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -492,13 +490,11 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
|
||||
@@ -77,6 +77,7 @@ type Table struct {
|
||||
hostnameResolvers []HostnameResolver
|
||||
refreshers []refresher
|
||||
initOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
refreshInterval int
|
||||
|
||||
dhcp *dhcp
|
||||
@@ -90,7 +91,9 @@ type Table struct {
|
||||
vni *virtualNetworkIface
|
||||
svcCfg ctrld.ServiceConfig
|
||||
quitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
selfIP string
|
||||
selfIPLock sync.RWMutex
|
||||
cdUID string
|
||||
ptrNameservers []string
|
||||
}
|
||||
@@ -103,6 +106,7 @@ func NewTable(cfg *ctrld.Config, selfIP, cdUID string, ns []string) *Table {
|
||||
return &Table{
|
||||
svcCfg: cfg.Service,
|
||||
quitCh: make(chan struct{}),
|
||||
stopCh: make(chan struct{}),
|
||||
selfIP: selfIP,
|
||||
cdUID: cdUID,
|
||||
ptrNameservers: ns,
|
||||
@@ -120,33 +124,80 @@ func (t *Table) AddLeaseFile(name string, format ctrld.LeaseFileFormat) {
|
||||
// RefreshLoop runs all the refresher to update new client info data.
|
||||
func (t *Table) RefreshLoop(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Second * time.Duration(t.refreshInterval))
|
||||
defer timer.Stop()
|
||||
defer func() {
|
||||
timer.Stop()
|
||||
close(t.quitCh)
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
for _, r := range t.refreshers {
|
||||
_ = r.refresh()
|
||||
}
|
||||
t.Refresh()
|
||||
case <-t.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
close(t.quitCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes all client info discovers.
|
||||
func (t *Table) Init() {
|
||||
t.initOnce.Do(t.init)
|
||||
}
|
||||
|
||||
// Refresh forces all discovers to retrieve new data.
|
||||
func (t *Table) Refresh() {
|
||||
for _, r := range t.refreshers {
|
||||
_ = r.refresh()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops all the discovers.
|
||||
// It blocks until all the discovers done.
|
||||
func (t *Table) Stop() {
|
||||
t.stopOnce.Do(func() {
|
||||
close(t.stopCh)
|
||||
})
|
||||
<-t.quitCh
|
||||
}
|
||||
|
||||
// SelfIP returns the selfIP value of the Table in a thread-safe manner.
|
||||
func (t *Table) SelfIP() string {
|
||||
t.selfIPLock.RLock()
|
||||
defer t.selfIPLock.RUnlock()
|
||||
return t.selfIP
|
||||
}
|
||||
|
||||
// SetSelfIP sets the selfIP value of the Table in a thread-safe manner.
|
||||
func (t *Table) SetSelfIP(ip string) {
|
||||
t.selfIPLock.Lock()
|
||||
defer t.selfIPLock.Unlock()
|
||||
t.selfIP = ip
|
||||
t.dhcp.selfIP = t.selfIP
|
||||
t.dhcp.addSelf()
|
||||
}
|
||||
|
||||
// initSelfDiscover initializes necessary client metadata for self query.
|
||||
func (t *Table) initSelfDiscover() {
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP}
|
||||
t.dhcp.addSelf()
|
||||
t.ipResolvers = append(t.ipResolvers, t.dhcp)
|
||||
t.macResolvers = append(t.macResolvers, t.dhcp)
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.dhcp)
|
||||
}
|
||||
|
||||
func (t *Table) init() {
|
||||
// Custom client ID presents, use it as the only source.
|
||||
if _, clientID := controld.ParseRawUID(t.cdUID); clientID != "" {
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start self discovery")
|
||||
t.dhcp = &dhcp{selfIP: t.selfIP}
|
||||
t.dhcp.addSelf()
|
||||
t.ipResolvers = append(t.ipResolvers, t.dhcp)
|
||||
t.macResolvers = append(t.macResolvers, t.dhcp)
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, t.dhcp)
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start self discovery with custom client id")
|
||||
t.initSelfDiscover()
|
||||
return
|
||||
}
|
||||
|
||||
// If we are running on platforms that should only do self discover, use it as the only source, too.
|
||||
if ctrld.SelfDiscover() {
|
||||
ctrld.ProxyLogger.Load().Debug().Msg("start self discovery on desktop platforms")
|
||||
t.initSelfDiscover()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -168,11 +219,10 @@ func (t *Table) init() {
|
||||
}
|
||||
for platform, discover := range discovers {
|
||||
if err := discover.refresh(); err != nil {
|
||||
ctrld.ProxyLogger.Load().Error().Err(err).Msgf("could not init %s discover", platform)
|
||||
} else {
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, discover)
|
||||
t.refreshers = append(t.refreshers, discover)
|
||||
ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to init %s discover", platform)
|
||||
}
|
||||
t.hostnameResolvers = append(t.hostnameResolvers, discover)
|
||||
t.refreshers = append(t.refreshers, discover)
|
||||
}
|
||||
}
|
||||
// Hosts file mapping.
|
||||
@@ -381,22 +431,30 @@ func (t *Table) lookupHostnameAll(ip, mac string) []*hostnameEntry {
|
||||
|
||||
// ListClients returns list of clients discovered by ctrld.
|
||||
func (t *Table) ListClients() []*Client {
|
||||
for _, r := range t.refreshers {
|
||||
_ = r.refresh()
|
||||
}
|
||||
t.Refresh()
|
||||
ipMap := make(map[string]*Client)
|
||||
il := []ipLister{t.dhcp, t.arp, t.ndp, t.ptr, t.mdns, t.vni}
|
||||
|
||||
for _, ir := range il {
|
||||
if ir == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ip := range ir.List() {
|
||||
c, ok := ipMap[ip]
|
||||
if !ok {
|
||||
c = &Client{
|
||||
IP: netip.MustParseAddr(ip),
|
||||
Source: map[string]struct{}{ir.String(): {}},
|
||||
// Validate IP before using MustParseAddr
|
||||
if addr, err := netip.ParseAddr(ip); err == nil {
|
||||
c, ok := ipMap[ip]
|
||||
if !ok {
|
||||
c = &Client{
|
||||
IP: addr,
|
||||
Source: map[string]struct{}{},
|
||||
}
|
||||
ipMap[ip] = c
|
||||
}
|
||||
// Safely get source name
|
||||
if src := ir.String(); src != "" {
|
||||
c.Source[src] = struct{}{}
|
||||
}
|
||||
ipMap[ip] = c
|
||||
} else {
|
||||
c.Source[ir.String()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,4 +16,5 @@ var clientInfoFiles = map[string]ctrld.LeaseFileFormat{
|
||||
"/var/dhcpd/var/db/dhcpd.leases": ctrld.IscDhcpd, // Pfsense
|
||||
"/home/pi/.router/run/dhcp/dnsmasq.leases": ctrld.Dnsmasq, // Firewalla
|
||||
"/var/lib/kea/dhcp4.leases": ctrld.KeaDHCP4, // Pfsense
|
||||
"/var/db/dnsmasq.leases": ctrld.Dnsmasq, // OPNsense
|
||||
}
|
||||
|
||||
@@ -74,7 +74,6 @@ func (m *mdns) lookupIPByHostname(name string, v6 bool) string {
|
||||
if value == name {
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
ip = addr.String()
|
||||
//lint:ignore S1008 This is used for readable.
|
||||
if addr.IsLoopback() { // Continue searching if this is loopback address.
|
||||
return true
|
||||
}
|
||||
@@ -92,6 +91,11 @@ func (m *mdns) init(quitCh chan struct{}) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if IPv6 is available once and use the result for the rest of the function.
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("checking for IPv6 availability in mdns init")
|
||||
ipv6 := ctrldnet.IPv6Available(context.Background())
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("IPv6 is %v in mdns init", ipv6)
|
||||
|
||||
v4ConnList := make([]*net.UDPConn, 0, len(ifaces))
|
||||
v6ConnList := make([]*net.UDPConn, 0, len(ifaces))
|
||||
for _, iface := range ifaces {
|
||||
@@ -102,7 +106,8 @@ func (m *mdns) init(quitCh chan struct{}) error {
|
||||
v4ConnList = append(v4ConnList, conn)
|
||||
go m.readLoop(conn)
|
||||
}
|
||||
if ctrldnet.IPv6Available(context.Background()) {
|
||||
|
||||
if ipv6 {
|
||||
if conn, err := net.ListenMulticastUDP("udp6", &iface, mdnsV6Addr); err == nil {
|
||||
v6ConnList = append(v6ConnList, conn)
|
||||
go m.readLoop(conn)
|
||||
|
||||
@@ -67,4 +67,16 @@ var services = [...]string{
|
||||
|
||||
// Merlin
|
||||
"_alexa._tcp",
|
||||
|
||||
// Newer Android TV devices
|
||||
"_androidtvremote2._tcp.local.",
|
||||
|
||||
// https://esphome.io/
|
||||
"_esphomelib._tcp.local.",
|
||||
|
||||
// https://www.home-assistant.io/
|
||||
"_home-assistant._tcp.local.",
|
||||
|
||||
// https://kno.wled.ge/
|
||||
"_wled._tcp.local.",
|
||||
}
|
||||
|
||||
@@ -104,7 +104,6 @@ func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string {
|
||||
if value == name {
|
||||
if addr, err := netip.ParseAddr(key.(string)); err == nil && addr.Is6() == v6 {
|
||||
ip = addr.String()
|
||||
//lint:ignore S1008 This is used for readable.
|
||||
if addr.IsLoopback() { // Continue searching if this is loopback address.
|
||||
return true
|
||||
}
|
||||
@@ -120,8 +119,7 @@ func (p *ptrDiscover) lookupIPByHostname(name string, v6 bool) string {
|
||||
// is reachable, set p.serverDown to false, so p.lookupHostname can continue working.
|
||||
func (p *ptrDiscover) checkServer() {
|
||||
bo := backoff.NewBackoff("ptrDiscover", func(format string, args ...any) {}, time.Minute*5)
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(".", dns.TypeNS)
|
||||
m := (&ctrld.UpstreamConfig{}).VerifyMsg()
|
||||
ping := func() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -3,6 +3,7 @@ package clientinfo
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
@@ -44,9 +45,9 @@ func (u *ubiosDiscover) refreshDevices() error {
|
||||
cmd := exec.Command("/usr/bin/mongo", "localhost:27117/ace", "--quiet", "--eval", `
|
||||
DBQuery.shellBatchSize = 256;
|
||||
db.user.find({name: {$exists: true, $ne: ""}}, {_id:0, mac:1, name:1});`)
|
||||
b, err := cmd.Output()
|
||||
b, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("out: %s, err: %w", string(b), err)
|
||||
}
|
||||
return u.storeDevices(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
@@ -24,10 +24,19 @@ import (
|
||||
|
||||
const (
|
||||
apiDomainCom = "api.controld.com"
|
||||
apiDomainComIPv4 = "147.185.34.1"
|
||||
apiDomainComIPv6 = "2606:1a40:3::1"
|
||||
apiDomainDev = "api.controld.dev"
|
||||
resolverDataURLCom = "https://api.controld.com/utility"
|
||||
resolverDataURLDev = "https://api.controld.dev/utility"
|
||||
apiDomainDevIPv4 = "23.171.240.84"
|
||||
apiURLCom = "https://api.controld.com"
|
||||
apiURLDev = "https://api.controld.dev"
|
||||
resolverDataURLCom = apiURLCom + "/utility"
|
||||
resolverDataURLDev = apiURLDev + "/utility"
|
||||
logURLCom = apiURLCom + "/logs"
|
||||
logURLDev = apiURLDev + "/logs"
|
||||
InvalidConfigCode = 40402
|
||||
defaultTimeout = 20 * time.Second
|
||||
sendLogTimeout = 300 * time.Second
|
||||
)
|
||||
|
||||
// ResolverConfig represents Control D resolver data.
|
||||
@@ -36,6 +45,7 @@ type ResolverConfig struct {
|
||||
Ctrld struct {
|
||||
CustomConfig string `json:"custom_config"`
|
||||
CustomLastUpdate int64 `json:"custom_last_update"`
|
||||
VersionTarget string `json:"version_target"`
|
||||
} `json:"ctrld"`
|
||||
Exclude []string `json:"exclude"`
|
||||
UID string `json:"uid"`
|
||||
@@ -49,14 +59,14 @@ type utilityResponse struct {
|
||||
} `json:"body"`
|
||||
}
|
||||
|
||||
type UtilityErrorResponse struct {
|
||||
type ErrorResponse struct {
|
||||
ErrorField struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func (u UtilityErrorResponse) Error() string {
|
||||
func (u ErrorResponse) Error() string {
|
||||
return u.ErrorField.Message
|
||||
}
|
||||
|
||||
@@ -71,6 +81,12 @@ type UtilityOrgRequest struct {
|
||||
Hostname string `json:"hostname"`
|
||||
}
|
||||
|
||||
// LogsRequest contains request data for sending runtime logs to API.
|
||||
type LogsRequest struct {
|
||||
UID string `json:"uid"`
|
||||
Data io.ReadCloser `json:"-"`
|
||||
}
|
||||
|
||||
// FetchResolverConfig fetch Control D config for given uid.
|
||||
func FetchResolverConfig(rawUID, version string, cdDev bool) (*ResolverConfig, error) {
|
||||
uid, clientID := ParseRawUID(rawUID)
|
||||
@@ -123,42 +139,19 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
apiDomain := apiDomainCom
|
||||
if cdDev {
|
||||
apiDomain = apiDomainDev
|
||||
}
|
||||
ips := ctrld.LookupIP(apiDomain)
|
||||
if len(ips) == 0 {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr)
|
||||
return ctrldnet.Dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
ctrld.ProxyLogger.Load().Debug().Msgf("API IPs: %v", ips)
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
addrs := make([]string, len(ips))
|
||||
for i := range ips {
|
||||
addrs[i] = net.JoinHostPort(ips[i], port)
|
||||
}
|
||||
d := &ctrldnet.ParallelDialer{}
|
||||
return d.DialContext(ctx, network, addrs)
|
||||
}
|
||||
|
||||
if router.Name() == ddwrt.Name || runtime.GOOS == "android" {
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
transport := apiTransport(cdDev)
|
||||
client := &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
resp, err := doWithFallback(client, req, apiServerIP(cdDev))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("client.Do: %w", err)
|
||||
return nil, fmt.Errorf("postUtilityAPI client.Do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
d := json.NewDecoder(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errResp := &UtilityErrorResponse{}
|
||||
errResp := &ErrorResponse{}
|
||||
if err := d.Decode(errResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -172,6 +165,43 @@ func postUtilityAPI(version string, cdDev, lastUpdatedFailed bool, body io.Reade
|
||||
return &ur.Body.Resolver, nil
|
||||
}
|
||||
|
||||
// SendLogs sends runtime log to ControlD API.
|
||||
func SendLogs(lr *LogsRequest, cdDev bool) error {
|
||||
defer lr.Data.Close()
|
||||
apiUrl := logURLCom
|
||||
if cdDev {
|
||||
apiUrl = logURLDev
|
||||
}
|
||||
req, err := http.NewRequest("POST", apiUrl, lr.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http.NewRequest: %w", err)
|
||||
}
|
||||
q := req.URL.Query()
|
||||
q.Set("uid", lr.UID)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
transport := apiTransport(cdDev)
|
||||
client := &http.Client{
|
||||
Timeout: sendLogTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
resp, err := doWithFallback(client, req, apiServerIP(cdDev))
|
||||
if err != nil {
|
||||
return fmt.Errorf("SendLogs client.Do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
d := json.NewDecoder(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errResp := &ErrorResponse{}
|
||||
if err := d.Decode(errResp); err != nil {
|
||||
return err
|
||||
}
|
||||
return errResp
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseRawUID parse the input raw UID, returning real UID and ClientID.
|
||||
// The raw UID can have 2 forms:
|
||||
//
|
||||
@@ -181,3 +211,94 @@ func ParseRawUID(rawUID string) (string, string) {
|
||||
uid, clientID, _ := strings.Cut(rawUID, "/")
|
||||
return uid, clientID
|
||||
}
|
||||
|
||||
// apiTransport returns an HTTP transport for connecting to ControlD API endpoint.
|
||||
func apiTransport(cdDev bool) *http.Transport {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
apiDomain := apiDomainCom
|
||||
apiIpsV4 := []string{apiDomainComIPv4}
|
||||
apiIpsV6 := []string{apiDomainComIPv6}
|
||||
apiIPs := []string{apiDomainComIPv4, apiDomainComIPv6}
|
||||
if cdDev {
|
||||
apiDomain = apiDomainDev
|
||||
apiIpsV4 = []string{apiDomainDevIPv4}
|
||||
apiIpsV6 = []string{}
|
||||
apiIPs = []string{apiDomainDevIPv4}
|
||||
}
|
||||
|
||||
ips := ctrld.LookupIP(apiDomain)
|
||||
if len(ips) == 0 {
|
||||
ctrld.ProxyLogger.Load().Warn().Msgf("No IPs found for %s, use direct IPs: %v", apiDomain, apiIPs)
|
||||
ips = apiIPs
|
||||
}
|
||||
|
||||
// Separate IPv4 and IPv6 addresses
|
||||
var ipv4s, ipv6s []string
|
||||
for _, ip := range ips {
|
||||
if strings.Contains(ip, ":") {
|
||||
ipv6s = append(ipv6s, ip)
|
||||
} else {
|
||||
ipv4s = append(ipv4s, ip)
|
||||
}
|
||||
}
|
||||
|
||||
dial := func(ctx context.Context, network string, addrs []string) (net.Conn, error) {
|
||||
d := &ctrldnet.ParallelDialer{}
|
||||
return d.DialContext(ctx, network, addrs, ctrld.ProxyLogger.Load())
|
||||
}
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
|
||||
// Try IPv4 first
|
||||
if len(ipv4s) > 0 {
|
||||
if conn, err := dial(ctx, "tcp4", addrsFromPort(ipv4s, port)); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
// Fallback to direct IPv4
|
||||
if conn, err := dial(ctx, "tcp4", addrsFromPort(apiIpsV4, port)); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Fallback to IPv6 if available
|
||||
if len(ipv6s) > 0 {
|
||||
if conn, err := dial(ctx, "tcp6", addrsFromPort(ipv6s, port)); err == nil {
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
// Fallback to direct IPv6
|
||||
return dial(ctx, "tcp6", addrsFromPort(apiIpsV6, port))
|
||||
}
|
||||
if router.Name() == ddwrt.Name || runtime.GOOS == "android" {
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
|
||||
}
|
||||
return transport
|
||||
}
|
||||
|
||||
func addrsFromPort(ips []string, port string) []string {
|
||||
addrs := make([]string, len(ips))
|
||||
for i, ip := range ips {
|
||||
addrs[i] = net.JoinHostPort(ip, port)
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
||||
func doWithFallback(client *http.Client, req *http.Request, apiIp string) (*http.Response, error) {
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
ctrld.ProxyLogger.Load().Warn().Err(err).Msgf("failed to send request, fallback to direct IP: %s", apiIp)
|
||||
ipReq := req.Clone(req.Context())
|
||||
ipReq.Host = apiIp
|
||||
ipReq.URL.Host = apiIp
|
||||
resp, err = client.Do(ipReq)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// apiServerIP returns the direct IP to connect to API server.
|
||||
func apiServerIP(cdDev bool) string {
|
||||
if cdDev {
|
||||
return apiDomainDevIPv4
|
||||
}
|
||||
return apiDomainComIPv4
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package net
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -11,28 +12,34 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"tailscale.com/logtail/backoff"
|
||||
)
|
||||
|
||||
const (
|
||||
controldIPv6Test = "ipv6.controld.io"
|
||||
v4BootstrapDNS = "76.76.2.22:53"
|
||||
v6BootstrapDNS = "[2606:1a40::22]:53"
|
||||
v4BootstrapDNS = "76.76.2.22:53"
|
||||
v6BootstrapDNS = "[2606:1a40::22]:53"
|
||||
v6BootstrapIP = "2606:1a40::22"
|
||||
defaultHTTPSPort = "443"
|
||||
defaultHTTPPort = "80"
|
||||
defaultDNSPort = "53"
|
||||
probeStackTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
var commonIPv6Ports = []string{defaultHTTPSPort, defaultHTTPPort, defaultDNSPort}
|
||||
|
||||
var Dialer = &net.Dialer{
|
||||
Resolver: &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
d := ParallelDialer{}
|
||||
d.Timeout = 10 * time.Second
|
||||
return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS})
|
||||
l := zerolog.New(io.Discard)
|
||||
return d.DialContext(ctx, "udp", []string{v4BootstrapDNS, v6BootstrapDNS}, &l)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
const probeStackTimeout = 2 * time.Second
|
||||
|
||||
var probeStackDialer = &net.Dialer{
|
||||
Resolver: Dialer.Resolver,
|
||||
Timeout: probeStackTimeout,
|
||||
@@ -48,9 +55,29 @@ func init() {
|
||||
stackOnce.Store(new(sync.Once))
|
||||
}
|
||||
|
||||
func supportIPv6(ctx context.Context) bool {
|
||||
_, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "443"))
|
||||
return err == nil
|
||||
// supportIPv6 checks for IPv6 connectivity by attempting to connect to predefined ports
|
||||
// on a specific IPv6 address.
|
||||
// Returns a boolean indicating if IPv6 is supported and the port on which the connection succeeded.
|
||||
// If no connection is successful, returns false and an empty string.
|
||||
func supportIPv6(ctx context.Context) (supported bool, successPort string) {
|
||||
for _, port := range commonIPv6Ports {
|
||||
if canConnectToIPv6Port(ctx, port) {
|
||||
return true, string(port)
|
||||
}
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// canConnectToIPv6Port attempts to establish a TCP connection to the specified port
|
||||
// using IPv6. Returns true if the connection was successful.
|
||||
func canConnectToIPv6Port(ctx context.Context, port string) bool {
|
||||
address := net.JoinHostPort(v6BootstrapIP, port)
|
||||
conn, err := probeStackDialer.DialContext(ctx, "tcp6", address)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = conn.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func supportListenIPv6Local() bool {
|
||||
@@ -104,7 +131,8 @@ func SupportsIPv6ListenLocal() bool {
|
||||
|
||||
// IPv6Available is like SupportsIPv6, but always do the check without caching.
|
||||
func IPv6Available(ctx context.Context) bool {
|
||||
return supportIPv6(ctx)
|
||||
hasV6, _ := supportIPv6(ctx)
|
||||
return hasV6
|
||||
}
|
||||
|
||||
// IsIPv6 checks if the provided IP is v6.
|
||||
@@ -133,7 +161,7 @@ type ParallelDialer struct {
|
||||
net.Dialer
|
||||
}
|
||||
|
||||
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string) (net.Conn, error) {
|
||||
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string, logger *zerolog.Logger) (net.Conn, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
@@ -153,11 +181,16 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs
|
||||
for _, addr := range addrs {
|
||||
go func(addr string) {
|
||||
defer wg.Done()
|
||||
logger.Debug().Msgf("dialing to %s", addr)
|
||||
conn, err := d.Dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
logger.Debug().Msgf("failed to dial %s: %v", addr, err)
|
||||
}
|
||||
select {
|
||||
case ch <- ¶llelDialerResult{conn: conn, err: err}:
|
||||
case <-done:
|
||||
if conn != nil {
|
||||
logger.Debug().Msgf("connection closed: %s", conn.RemoteAddr())
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
@@ -168,6 +201,7 @@ func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs
|
||||
for res := range ch {
|
||||
if res.err == nil {
|
||||
cancel()
|
||||
logger.Debug().Msgf("connected to %s", res.conn.RemoteAddr())
|
||||
return res.conn, res.err
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
|
||||
@@ -12,7 +12,12 @@ func TestProbeStackTimeout(t *testing.T) {
|
||||
go func() {
|
||||
defer close(done)
|
||||
close(started)
|
||||
supportIPv6(context.Background())
|
||||
hasV6, port := supportIPv6(context.Background())
|
||||
if hasV6 {
|
||||
t.Logf("connect to port %s using ipv6: %v", port, hasV6)
|
||||
} else {
|
||||
t.Log("ipv6 is not available")
|
||||
}
|
||||
}()
|
||||
|
||||
<-started
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
|
||||
"tailscale.com/net/dns/resolvconffile"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
const resolvconfPath = "/etc/resolv.conf"
|
||||
@@ -22,7 +23,7 @@ func NameServersWithPort() []string {
|
||||
return ns
|
||||
}
|
||||
|
||||
func NameServers(_ string) []string {
|
||||
func NameServers() []string {
|
||||
c, err := resolvconffile.ParseFile(resolvconfPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
@@ -33,3 +34,12 @@ func NameServers(_ string) []string {
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// SearchDomains returns the current search domains config in /etc/resolv.conf file.
|
||||
func SearchDomains() ([]dnsname.FQDN, error) {
|
||||
c, err := resolvconffile.ParseFile(resolvconfPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.SearchDomains, nil
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestNameServers(t *testing.T) {
|
||||
ns := NameServers("")
|
||||
ns := NameServers()
|
||||
require.NotNil(t, ns)
|
||||
t.Log(ns)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -28,3 +29,62 @@ func interfaceNameFromReader(r io.Reader) (string, error) {
|
||||
}
|
||||
return "", errors.New("not found")
|
||||
}
|
||||
|
||||
// AdditionalConfigFiles returns a list of Dnsmasq configuration files found in the "/tmp/etc" directory.
|
||||
func AdditionalConfigFiles() []string {
|
||||
if paths, err := filepath.Glob("/tmp/etc/dnsmasq-*.conf"); err == nil {
|
||||
return paths
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdditionalLeaseFiles returns a list of lease file paths corresponding to the Dnsmasq configuration files.
|
||||
func AdditionalLeaseFiles() []string {
|
||||
cfgFiles := AdditionalConfigFiles()
|
||||
if len(cfgFiles) == 0 {
|
||||
return nil
|
||||
}
|
||||
leaseFiles := make([]string, 0, len(cfgFiles))
|
||||
for _, cfgFile := range cfgFiles {
|
||||
if leaseFile := leaseFileFromConfigFileName(cfgFile); leaseFile != "" {
|
||||
leaseFiles = append(leaseFiles, leaseFile)
|
||||
|
||||
} else {
|
||||
leaseFiles = append(leaseFiles, defaultLeaseFileFromConfigPath(cfgFile))
|
||||
}
|
||||
}
|
||||
return leaseFiles
|
||||
}
|
||||
|
||||
// leaseFileFromConfigFileName retrieves the DHCP lease file path by reading and parsing the provided configuration file.
|
||||
func leaseFileFromConfigFileName(cfgFile string) string {
|
||||
if f, err := os.Open(cfgFile); err == nil {
|
||||
return leaseFileFromReader(f)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// leaseFileFromReader parses the given io.Reader for the "dhcp-leasefile" configuration and returns its value as a string.
|
||||
func leaseFileFromReader(r io.Reader) string {
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
before, after, found := strings.Cut(line, "=")
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
if before == "dhcp-leasefile" {
|
||||
return after
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// defaultLeaseFileFromConfigPath generates the default lease file path based on the provided configuration file path.
|
||||
func defaultLeaseFileFromConfigPath(path string) string {
|
||||
name := filepath.Base(path)
|
||||
return filepath.Join("/var/lib/misc", strings.TrimSuffix(name, ".conf")+".leases")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dnsmasq
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@@ -44,3 +45,49 @@ interface=eth0
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_leaseFileFromReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in io.Reader
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"default",
|
||||
strings.NewReader(`
|
||||
dhcp-script=/sbin/dhcpc_lease
|
||||
dhcp-leasefile=/var/lib/misc/dnsmasq-1.leases
|
||||
script-arp
|
||||
`),
|
||||
"/var/lib/misc/dnsmasq-1.leases",
|
||||
},
|
||||
{
|
||||
"non-default",
|
||||
strings.NewReader(`
|
||||
dhcp-script=/sbin/dhcpc_lease
|
||||
dhcp-leasefile=/tmp/var/lib/misc/dnsmasq-1.leases
|
||||
script-arp
|
||||
`),
|
||||
"/tmp/var/lib/misc/dnsmasq-1.leases",
|
||||
},
|
||||
{
|
||||
"missing",
|
||||
strings.NewReader(`
|
||||
dhcp-script=/sbin/dhcpc_lease
|
||||
script-arp
|
||||
`),
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := leaseFileFromReader(tc.in); got != tc.expected {
|
||||
t.Errorf("leaseFileFromReader() = %v, want %v", got, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"html/template"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
@@ -26,7 +27,13 @@ max-cache-ttl=0
|
||||
{{- end}}
|
||||
`
|
||||
|
||||
const MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf"
|
||||
const (
|
||||
MerlinConfPath = "/tmp/etc/dnsmasq.conf"
|
||||
MerlinJffsConfDir = "/jffs/configs"
|
||||
MerlinJffsConfPath = "/jffs/configs/dnsmasq.conf"
|
||||
MerlinPostConfPath = "/jffs/scripts/dnsmasq.postconf"
|
||||
)
|
||||
|
||||
const MerlinPostConfMarker = `# GENERATED BY ctrld - EOF`
|
||||
const MerlinPostConfTmpl = `# GENERATED BY ctrld - DO NOT MODIFY
|
||||
|
||||
@@ -157,3 +164,27 @@ func FirewallaSelfInterfaces() []*net.Interface {
|
||||
}
|
||||
return ifaces
|
||||
}
|
||||
|
||||
const (
|
||||
ubios43ConfPath = "/run/dnsmasq.dhcp.conf.d"
|
||||
ubios42ConfPath = "/run/dnsmasq.conf.d"
|
||||
ubios43PidFile = "/run/dnsmasq-main.pid"
|
||||
ubios42PidFile = "/run/dnsmasq.pid"
|
||||
UbiosConfName = "zzzctrld.conf"
|
||||
)
|
||||
|
||||
// UbiosConfPath returns the appropriate configuration path based on the system's directory structure.
|
||||
func UbiosConfPath() string {
|
||||
if st, _ := os.Stat(ubios43ConfPath); st != nil && st.IsDir() {
|
||||
return ubios43ConfPath
|
||||
}
|
||||
return ubios42ConfPath
|
||||
}
|
||||
|
||||
// UbiosPidFile returns the appropriate dnsmasq pid file based on the system's directory structure.
|
||||
func UbiosPidFile() string {
|
||||
if st, _ := os.Stat(ubios43PidFile); st != nil && !st.IsDir() {
|
||||
return ubios43PidFile
|
||||
}
|
||||
return ubios42PidFile
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -181,7 +182,7 @@ func ContentFilteringEnabled() bool {
|
||||
// DnsShieldEnabled reports whether DNS Shield is enabled.
|
||||
// See: https://community.ui.com/releases/UniFi-OS-Dream-Machines-3-2-7/251dfc1e-f4dd-4264-a080-3be9d8b9e02b
|
||||
func DnsShieldEnabled() bool {
|
||||
buf, err := os.ReadFile("/var/run/dnsmasq.conf.d/dns.conf")
|
||||
buf, err := os.ReadFile(filepath.Join(dnsmasq.UbiosConfPath(), "dns.conf"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -3,8 +3,10 @@ package merlin
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
@@ -19,10 +21,18 @@ import (
|
||||
|
||||
const Name = "merlin"
|
||||
|
||||
// nvramKvMap is a map of NVRAM key-value pairs used to configure and manage Merlin-specific settings.
|
||||
var nvramKvMap = map[string]string{
|
||||
"dnspriv_enable": "0", // Ensure Merlin native DoT disabled.
|
||||
}
|
||||
|
||||
// dnsmasqConfig represents configuration paths for dnsmasq operations in Merlin firmware.
|
||||
type dnsmasqConfig struct {
|
||||
confPath string
|
||||
jffsConfPath string
|
||||
}
|
||||
|
||||
// Merlin represents a configuration handler for setting up and managing ctrld on Merlin routers.
|
||||
type Merlin struct {
|
||||
cfg *ctrld.Config
|
||||
}
|
||||
@@ -32,18 +42,22 @@ func New(cfg *ctrld.Config) *Merlin {
|
||||
return &Merlin{cfg: cfg}
|
||||
}
|
||||
|
||||
// ConfigureService configures the service based on the provided configuration. It returns an error if the configuration fails.
|
||||
func (m *Merlin) ConfigureService(config *service.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Install sets up the necessary configurations and services required for the Merlin instance to function properly.
|
||||
func (m *Merlin) Install(_ *service.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Uninstall removes the ctrld-related configurations and services from the Merlin router and reverts to the original state.
|
||||
func (m *Merlin) Uninstall(_ *service.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PreRun prepares the Merlin instance for operation by waiting for essential services and directories to become available.
|
||||
func (m *Merlin) PreRun() error {
|
||||
// Wait NTP ready.
|
||||
_ = m.Cleanup()
|
||||
@@ -65,6 +79,7 @@ func (m *Merlin) PreRun() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Setup initializes and configures the Merlin instance for use, including setting up dnsmasq and necessary nvram settings.
|
||||
func (m *Merlin) Setup() error {
|
||||
if m.cfg.FirstListener().IsDirectDnsListener() {
|
||||
return nil
|
||||
@@ -73,30 +88,17 @@ func (m *Merlin) Setup() error {
|
||||
if val, _ := nvram.Run("get", nvram.CtrldSetupKey); val == "1" {
|
||||
return nil
|
||||
}
|
||||
buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath)
|
||||
// Already setup.
|
||||
if bytes.Contains(buf, []byte(dnsmasq.MerlinPostConfMarker)) {
|
||||
return nil
|
||||
}
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
|
||||
if err := m.writeDnsmasqPostconf(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := dnsmasq.ConfTmpl(dnsmasq.MerlinPostConfTmpl, m.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = strings.Join([]string{
|
||||
data,
|
||||
"\n",
|
||||
dnsmasq.MerlinPostConfMarker,
|
||||
"\n",
|
||||
string(buf),
|
||||
}, "\n")
|
||||
// Write dnsmasq post conf file.
|
||||
if err := os.WriteFile(dnsmasq.MerlinPostConfPath, []byte(data), 0750); err != nil {
|
||||
return err
|
||||
for _, cfg := range getDnsmasqConfigs() {
|
||||
if err := m.setupDnsmasq(cfg); err != nil {
|
||||
return fmt.Errorf("failed to setup dnsmasq: config: %s, error: %w", cfg.confPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
@@ -109,6 +111,7 @@ func (m *Merlin) Setup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup restores the original dnsmasq and nvram configurations and restarts dnsmasq if necessary.
|
||||
func (m *Merlin) Cleanup() error {
|
||||
if m.cfg.FirstListener().IsDirectDnsListener() {
|
||||
return nil
|
||||
@@ -130,6 +133,12 @@ func (m *Merlin) Cleanup() error {
|
||||
if err := os.WriteFile(dnsmasq.MerlinPostConfPath, merlinParsePostConf(buf), 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, cfg := range getDnsmasqConfigs() {
|
||||
if err := m.cleanupDnsmasqJffs(cfg); err != nil {
|
||||
return fmt.Errorf("failed to cleanup jffs dnsmasq: config: %s, error: %w", cfg.confPath, err)
|
||||
}
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
@@ -137,6 +146,81 @@ func (m *Merlin) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDnsmasq sets up dnsmasq configuration by writing postconf, copying configuration, and running a postconf script.
|
||||
func (m *Merlin) setupDnsmasq(cfg *dnsmasqConfig) error {
|
||||
src, err := os.Open(cfg.confPath)
|
||||
if os.IsNotExist(err) {
|
||||
return nil // nothing to do if conf file does not exist.
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open dnsmasq config: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// Copy current dnsmasq config to cfg.jffsConfPath,
|
||||
// Then we will run postconf script on this file.
|
||||
//
|
||||
// Normally, adding postconf script is enough. However, we see
|
||||
// reports on some Merlin devices that postconf scripts does not
|
||||
// work, but manipulating the config directly via /jffs/configs does.
|
||||
dst, err := os.Create(cfg.jffsConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create %s: %w", cfg.jffsConfPath, err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
if _, err := io.Copy(dst, src); err != nil {
|
||||
return fmt.Errorf("failed to copy current dnsmasq config: %w", err)
|
||||
}
|
||||
if err := dst.Close(); err != nil {
|
||||
return fmt.Errorf("failed to save %s: %w", cfg.jffsConfPath, err)
|
||||
}
|
||||
|
||||
// Run postconf script on cfg.jffsConfPath directly.
|
||||
cmd := exec.Command("/bin/sh", dnsmasq.MerlinPostConfPath, cfg.jffsConfPath)
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to run post conf: %s: %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupDnsmasqJffs removes the JFFS configuration file specified in the given dnsmasqConfig, if it exists.
|
||||
func (m *Merlin) cleanupDnsmasqJffs(cfg *dnsmasqConfig) error {
|
||||
// Remove cfg.jffsConfPath file.
|
||||
if err := os.Remove(cfg.jffsConfPath); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeDnsmasqPostconf writes the requireddnsmasqConfigs post-configuration for dnsmasq to enable custom DNS settings with ctrld.
|
||||
func (m *Merlin) writeDnsmasqPostconf() error {
|
||||
buf, err := os.ReadFile(dnsmasq.MerlinPostConfPath)
|
||||
// Already setup.
|
||||
if bytes.Contains(buf, []byte(dnsmasq.MerlinPostConfMarker)) {
|
||||
return nil
|
||||
}
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := dnsmasq.ConfTmpl(dnsmasq.MerlinPostConfTmpl, m.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data = strings.Join([]string{
|
||||
data,
|
||||
"\n",
|
||||
dnsmasq.MerlinPostConfMarker,
|
||||
"\n",
|
||||
string(buf),
|
||||
}, "\n")
|
||||
// Write dnsmasq post conf file.
|
||||
return os.WriteFile(dnsmasq.MerlinPostConfPath, []byte(data), 0750)
|
||||
}
|
||||
|
||||
// restartDNSMasq restarts the dnsmasq service by executing the appropriate system command using "service".
|
||||
// Returns an error if the command fails or if there is an issue processing the command output.
|
||||
func restartDNSMasq() error {
|
||||
if out, err := exec.Command("service", "restart_dnsmasq").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("restart_dnsmasq: %s, %w", string(out), err)
|
||||
@@ -144,6 +228,22 @@ func restartDNSMasq() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getDnsmasqConfigs retrieves a list of dnsmasqConfig containing configuration and JFFS paths for dnsmasq operations.
|
||||
func getDnsmasqConfigs() []*dnsmasqConfig {
|
||||
cfgs := []*dnsmasqConfig{
|
||||
{dnsmasq.MerlinConfPath, dnsmasq.MerlinJffsConfPath},
|
||||
}
|
||||
for _, path := range dnsmasq.AdditionalConfigFiles() {
|
||||
jffsConfPath := filepath.Join(dnsmasq.MerlinJffsConfDir, filepath.Base(path))
|
||||
cfgs = append(cfgs, &dnsmasqConfig{path, jffsConfPath})
|
||||
}
|
||||
|
||||
return cfgs
|
||||
}
|
||||
|
||||
// merlinParsePostConf parses the dnsmasq post configuration by removing content after the MerlinPostConfMarker, if present.
|
||||
// If no marker is found, the original buffer is returned unmodified.
|
||||
// Returns nil if the input buffer is empty.
|
||||
func merlinParsePostConf(buf []byte) []byte {
|
||||
if len(buf) == 0 {
|
||||
return nil
|
||||
@@ -155,6 +255,7 @@ func merlinParsePostConf(buf []byte) []byte {
|
||||
return buf
|
||||
}
|
||||
|
||||
// waitDirExists waits until the specified directory exists, polling its existence every second.
|
||||
func waitDirExists(dir string) {
|
||||
for {
|
||||
if _, err := os.Stat(dir); !os.IsNotExist(err) {
|
||||
|
||||
@@ -2,10 +2,13 @@ package openwrt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -15,10 +18,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Name = "openwrt"
|
||||
openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf"
|
||||
Name = "openwrt"
|
||||
openwrtDNSMasqConfigName = "ctrld.conf"
|
||||
openwrtDNSMasqDefaultConfigDir = "/tmp/dnsmasq.d"
|
||||
)
|
||||
|
||||
var openwrtDnsmasqDefaultConfigPath = filepath.Join(openwrtDNSMasqDefaultConfigDir, openwrtDNSMasqConfigName)
|
||||
|
||||
type Openwrt struct {
|
||||
cfg *ctrld.Config
|
||||
dnsmasqCacheSize string
|
||||
@@ -67,7 +73,7 @@ func (o *Openwrt) Setup() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(data), 0600); err != nil {
|
||||
if err := os.WriteFile(dnsmasqConfPathFromUbus(), []byte(data), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
@@ -82,7 +88,7 @@ func (o *Openwrt) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
// Remove the custom dnsmasq config
|
||||
if err := os.Remove(openwrtDNSMasqConfigPath); err != nil {
|
||||
if err := os.Remove(dnsmasqConfPathFromUbus()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -126,3 +132,60 @@ func uci(args ...string) (string, error) {
|
||||
}
|
||||
return strings.TrimSpace(stdout.String()), nil
|
||||
}
|
||||
|
||||
// openwrtServiceList represents openwrt services config.
|
||||
type openwrtServiceList struct {
|
||||
Dnsmasq dnsmasqConf `json:"dnsmasq"`
|
||||
}
|
||||
|
||||
// dnsmasqConf represents dnsmasq config.
|
||||
type dnsmasqConf struct {
|
||||
Instances map[string]confInstances `json:"instances"`
|
||||
}
|
||||
|
||||
// confInstances represents an instance config of a service.
|
||||
type confInstances struct {
|
||||
Mount map[string]string `json:"mount"`
|
||||
}
|
||||
|
||||
// dnsmasqConfPath returns the dnsmasq config path.
|
||||
//
|
||||
// Since version 24.10, openwrt makes some changes to dnsmasq to support
|
||||
// multiple instances of dnsmasq. This change causes breaking changes to
|
||||
// software which depends on the default dnsmasq path.
|
||||
//
|
||||
// There are some discussion/PRs in openwrt repo to address this:
|
||||
//
|
||||
// - https://github.com/openwrt/openwrt/pull/16806
|
||||
// - https://github.com/openwrt/openwrt/pull/16890
|
||||
//
|
||||
// In the meantime, workaround this problem by querying the actual config path
|
||||
// by querying ubus service list.
|
||||
func dnsmasqConfPath(r io.Reader) string {
|
||||
var svc openwrtServiceList
|
||||
if err := json.NewDecoder(r).Decode(&svc); err != nil {
|
||||
return openwrtDnsmasqDefaultConfigPath
|
||||
}
|
||||
for _, inst := range svc.Dnsmasq.Instances {
|
||||
for mount := range inst.Mount {
|
||||
dirName := filepath.Base(mount)
|
||||
parts := strings.Split(dirName, ".")
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
if parts[0] == "dnsmasq" && parts[len(parts)-1] == "d" {
|
||||
return filepath.Join(mount, openwrtDNSMasqConfigName)
|
||||
}
|
||||
}
|
||||
}
|
||||
return openwrtDnsmasqDefaultConfigPath
|
||||
}
|
||||
|
||||
// dnsmasqConfPathFromUbus get dnsmasq config path from ubus service list.
|
||||
func dnsmasqConfPathFromUbus() string {
|
||||
output, err := exec.Command("ubus", "call", "service", "list").Output()
|
||||
if err != nil {
|
||||
return openwrtDnsmasqDefaultConfigPath
|
||||
}
|
||||
return dnsmasqConfPath(bytes.NewReader(output))
|
||||
}
|
||||
|
||||
58
internal/router/openwrt/openwrt_test.go
Normal file
58
internal/router/openwrt/openwrt_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package openwrt
|
||||
|
||||
import (
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Sample output from https://github.com/openwrt/openwrt/pull/16806#issuecomment-2448255734
|
||||
const ubusDnsmasqBefore2410 = `{
|
||||
"dnsmasq": {
|
||||
"instances": {
|
||||
"guest_dns": {
|
||||
"mount": {
|
||||
"/tmp/dnsmasq.d": "0",
|
||||
"/var/run/dnsmasq/": "1"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
const ubusDnsmasq2410 = `{
|
||||
"dnsmasq": {
|
||||
"instances": {
|
||||
"guest_dns": {
|
||||
"mount": {
|
||||
"/tmp/dnsmasq.guest_dns.d": "0",
|
||||
"/var/run/dnsmasq/": "1"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
func Test_dnsmasqConfPath(t *testing.T) {
|
||||
var dnsmasq2410expected = filepath.Join("/tmp/dnsmasq.guest_dns.d", openwrtDNSMasqConfigName)
|
||||
tests := []struct {
|
||||
name string
|
||||
in io.Reader
|
||||
expected string
|
||||
}{
|
||||
{"empty", strings.NewReader(""), openwrtDnsmasqDefaultConfigPath},
|
||||
{"invalid", strings.NewReader("}}"), openwrtDnsmasqDefaultConfigPath},
|
||||
{"before 24.10", strings.NewReader(ubusDnsmasqBefore2410), openwrtDnsmasqDefaultConfigPath},
|
||||
{"24.10", strings.NewReader(ubusDnsmasq2410), dnsmasq2410expected},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := dnsmasqConfPath(tc.in); got != tc.expected {
|
||||
t.Errorf("dnsmasqConfPath() = %v, want %v", got, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -215,6 +215,20 @@ func LeaseFilesDir() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// ServiceDependencies returns list of dependencies that ctrld services needs on this router.
|
||||
// See https://pkg.go.dev/github.com/kardianos/service#Config for list format.
|
||||
func ServiceDependencies() []string {
|
||||
if Name() == ubios.Name {
|
||||
// On Ubios, ctrld needs to start after unifi-mongodb,
|
||||
// so it can query custom client info mapping.
|
||||
return []string{
|
||||
"Wants=unifi-mongodb.service",
|
||||
"After=unifi-mongodb.service",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func distroName() string {
|
||||
switch {
|
||||
case bytes.HasPrefix(unameO(), []byte("DD-WRT")):
|
||||
|
||||
@@ -45,11 +45,15 @@ func (s *tomatoSvc) Platform() string {
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) configPath() string {
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
bin := s.Config.Executable
|
||||
if bin == "" {
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
bin = path
|
||||
}
|
||||
return path + ".startup"
|
||||
return bin + ".startup"
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) template() *template.Template {
|
||||
|
||||
@@ -13,14 +13,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/dnsmasq"
|
||||
)
|
||||
|
||||
// This is a copy of https://github.com/kardianos/service/blob/v1.2.1/service_sysv_linux.go,
|
||||
// with modification for supporting ubios v1 init system.
|
||||
|
||||
// Keep in sync with ubios.ubiosDNSMasqConfigPath
|
||||
const ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
|
||||
|
||||
type ubiosSvc struct {
|
||||
i service.Interface
|
||||
platform string
|
||||
@@ -86,7 +85,7 @@ func (s *ubiosSvc) Install() error {
|
||||
}{
|
||||
s.Config,
|
||||
path,
|
||||
ubiosDNSMasqConfigPath,
|
||||
filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName),
|
||||
}
|
||||
|
||||
if err := s.template().Execute(f, to); err != nil {
|
||||
@@ -219,6 +218,8 @@ const ubiosBootSystemdService = `[Unit]
|
||||
Description=Run ctrld On Startup UDM
|
||||
Wants=network-online.target
|
||||
After=network-online.target
|
||||
Wants=unifi-mongodb
|
||||
After=unifi-mongodb
|
||||
StartLimitIntervalSec=500
|
||||
StartLimitBurst=5
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package ubios
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -12,19 +13,19 @@ import (
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/edgeos"
|
||||
)
|
||||
|
||||
const (
|
||||
Name = "ubios"
|
||||
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
|
||||
ubiosDNSMasqDnsConfigPath = "/run/dnsmasq.conf.d/dns.conf"
|
||||
)
|
||||
const Name = "ubios"
|
||||
|
||||
type Ubios struct {
|
||||
cfg *ctrld.Config
|
||||
cfg *ctrld.Config
|
||||
dnsmasqConfPath string
|
||||
}
|
||||
|
||||
// New returns a router.Router for configuring/setup/run ctrld on Ubios routers.
|
||||
func New(cfg *ctrld.Config) *Ubios {
|
||||
return &Ubios{cfg: cfg}
|
||||
return &Ubios{
|
||||
cfg: cfg,
|
||||
dnsmasqConfPath: filepath.Join(dnsmasq.UbiosConfPath(), dnsmasq.UbiosConfName),
|
||||
}
|
||||
}
|
||||
|
||||
func (u *Ubios) ConfigureService(config *service.Config) error {
|
||||
@@ -59,7 +60,7 @@ func (u *Ubios) Setup() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(data), 0600); err != nil {
|
||||
if err := os.WriteFile(u.dnsmasqConfPath, []byte(data), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
@@ -74,7 +75,7 @@ func (u *Ubios) Cleanup() error {
|
||||
return nil
|
||||
}
|
||||
// Remove the custom dnsmasq config
|
||||
if err := os.Remove(ubiosDNSMasqConfigPath); err != nil {
|
||||
if err := os.Remove(u.dnsmasqConfPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
@@ -85,7 +86,7 @@ func (u *Ubios) Cleanup() error {
|
||||
}
|
||||
|
||||
func restartDNSMasq() error {
|
||||
buf, err := os.ReadFile("/run/dnsmasq.pid")
|
||||
buf, err := os.ReadFile(dnsmasq.UbiosPidFile())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
5
log.go
5
log.go
@@ -9,11 +9,6 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func init() {
|
||||
l := zerolog.New(io.Discard)
|
||||
ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// ProxyLog emits the log record for proxy operations.
|
||||
// The caller should set it only once.
|
||||
// DEPRECATED: use ProxyLogger instead.
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromRIB, dnsFromIPConfig}
|
||||
return []dnsFn{dnsFromResolvConf, dnsFromRIB}
|
||||
}
|
||||
|
||||
func dnsFromRIB() []string {
|
||||
@@ -49,18 +46,6 @@ func dnsFromRIB() []string {
|
||||
return dns
|
||||
}
|
||||
|
||||
func dnsFromIPConfig() []string {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
cmd := exec.Command("ipconfig", "getoption", "", "domain_name_server")
|
||||
out, _ := cmd.Output()
|
||||
if ip := net.ParseIP(strings.TrimSpace(string(out))); ip != nil {
|
||||
return []string{ip.String()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toNetIP(addr route.Addr) net.IP {
|
||||
switch t := addr.(type) {
|
||||
case *route.Inet4Addr:
|
||||
|
||||
236
nameservers_darwin.go
Normal file
236
nameservers_darwin.go
Normal file
@@ -0,0 +1,236 @@
|
||||
//go:build darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromResolvConf, getDNSFromScutil, getAllDHCPNameservers}
|
||||
}
|
||||
|
||||
func getDNSFromScutil() []string {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
const (
|
||||
maxRetries = 10
|
||||
retryInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
|
||||
var nameservers []string
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(retryInterval)
|
||||
}
|
||||
|
||||
cmd := exec.Command("scutil", "--dns")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Error(), "failed to execute scutil --dns (attempt %d/%d): %v", attempt+1, maxRetries, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var localDNS []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
scanner := bufio.NewScanner(bytes.NewReader(output))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "nameserver[") {
|
||||
parts := strings.Split(line, ":")
|
||||
if len(parts) == 2 {
|
||||
ns := strings.TrimSpace(parts[1])
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback IPs
|
||||
isLocal := false
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
ipStr := v.String()
|
||||
if ip.String() == ipStr {
|
||||
isLocal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isLocal && !seen[ip.String()] {
|
||||
seen[ip.String()] = true
|
||||
localDNS = append(localDNS, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
Log(context.Background(), logger.Error(), "error scanning scutil output (attempt %d/%d): %v", attempt+1, maxRetries, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// If we successfully read the output and found nameservers, return them
|
||||
if len(localDNS) > 0 {
|
||||
return localDNS
|
||||
}
|
||||
}
|
||||
|
||||
return nameservers
|
||||
}
|
||||
|
||||
func getDHCPNameservers(iface string) ([]string, error) {
|
||||
// Run the ipconfig command for the given interface.
|
||||
cmd := exec.Command("ipconfig", "getpacket", iface)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error running ipconfig: %v", err)
|
||||
}
|
||||
|
||||
// Look for a line like:
|
||||
// domain_name_servers = 192.168.1.1 8.8.8.8;
|
||||
re := regexp.MustCompile(`domain_name_servers\s*=\s*(.*);`)
|
||||
matches := re.FindStringSubmatch(string(output))
|
||||
if len(matches) < 2 {
|
||||
return nil, fmt.Errorf("no DHCP nameservers found")
|
||||
}
|
||||
|
||||
// Split the nameservers by whitespace.
|
||||
nameservers := strings.Fields(matches[1])
|
||||
return nameservers, nil
|
||||
}
|
||||
|
||||
func getAllDHCPNameservers() []string {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
|
||||
var allNameservers []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, iface := range interfaces {
|
||||
// Skip interfaces that are:
|
||||
// - down
|
||||
// - loopback
|
||||
// - not physical (virtual)
|
||||
// - point-to-point (like VPN interfaces)
|
||||
// - without MAC address (non-physical)
|
||||
if iface.Flags&net.FlagUp == 0 ||
|
||||
iface.Flags&net.FlagLoopback != 0 ||
|
||||
iface.Flags&net.FlagPointToPoint != 0 ||
|
||||
(iface.Flags&net.FlagBroadcast == 0 &&
|
||||
iface.Flags&net.FlagMulticast == 0) ||
|
||||
len(iface.HardwareAddr) == 0 ||
|
||||
strings.HasPrefix(iface.Name, "utun") ||
|
||||
strings.HasPrefix(iface.Name, "llw") ||
|
||||
strings.HasPrefix(iface.Name, "awdl") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify it's a valid MAC address (should be 6 bytes for IEEE 802 MAC-48)
|
||||
if len(iface.HardwareAddr) != 6 {
|
||||
continue
|
||||
}
|
||||
|
||||
nameservers, err := getDHCPNameservers(iface.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Add unique nameservers to the result, skipping local IPs
|
||||
for _, ns := range nameservers {
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback and local IPs
|
||||
isLocal := false
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
if ip.String() == v.String() {
|
||||
isLocal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isLocal && !seen[ns] {
|
||||
seen[ns] = true
|
||||
allNameservers = append(allNameservers, ns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if we have static DNS servers saved for the current default route, we should add them to the list
|
||||
drIfaceName, err := netmon.DefaultRouteInterface()
|
||||
Log(context.Background(), logger.Debug(), "checking for static DNS servers for default route interface: %s", drIfaceName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get default route interface: %v", err)
|
||||
} else {
|
||||
drIface, err := net.InterfaceByName(drIfaceName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get interface by name %s: %v", drIfaceName, err)
|
||||
} else if drIface != nil {
|
||||
if _, err := patchNetIfaceName(drIface); err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to patch interface name %s: %v", drIfaceName, err)
|
||||
}
|
||||
staticNs, file := SavedStaticNameservers(drIface)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"static dns servers from %s: %v", file, staticNs)
|
||||
if len(staticNs) > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Adding static DNS servers from %s: %v", drIface.Name, staticNs)
|
||||
allNameservers = append(allNameservers, staticNs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allNameservers
|
||||
}
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
patched := false
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
patched = true
|
||||
iface.Name = name
|
||||
}
|
||||
return patched, nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
scanner := bufio.NewScanner(r)
|
||||
prevLine := ""
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, "*") {
|
||||
// Network services is disabled.
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(line, "Device: "+ifaceName) {
|
||||
prevLine = line
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(prevLine, " ", 2)
|
||||
if len(parts) == 2 {
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -5,9 +5,12 @@ import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
@@ -17,7 +20,7 @@ const (
|
||||
)
|
||||
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dns4, dns6, dnsFromSystemdResolver}
|
||||
return []dnsFn{dnsFromResolvConf, dns4, dns6, dnsFromSystemdResolver}
|
||||
}
|
||||
|
||||
func dns4() []string {
|
||||
@@ -128,3 +131,25 @@ func virtualInterfaces() set {
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// validInterfacesMap returns a set containing non virtual interfaces.
|
||||
// TODO: deduplicated with cmd/cli/net_linux.go in v2.
|
||||
func validInterfaces() set {
|
||||
m := make(map[string]struct{})
|
||||
vis := virtualInterfaces()
|
||||
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
|
||||
if _, existed := vis[i.Name]; existed {
|
||||
return
|
||||
}
|
||||
m[i.Name] = struct{}{}
|
||||
})
|
||||
// Fallback to default route interface if found nothing.
|
||||
if len(m) == 0 {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return m
|
||||
}
|
||||
m[defaultRoute.InterfaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
@@ -2,8 +2,63 @@
|
||||
|
||||
package ctrld
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
import (
|
||||
"net"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
func nameserversFromResolvconf() []string {
|
||||
return resolvconffile.NameServers("")
|
||||
"tailscale.com/net/netmon"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// currentNameserversFromResolvconf returns the current nameservers set from /etc/resolv.conf file.
|
||||
func currentNameserversFromResolvconf() []string {
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// dnsFromResolvConf reads usable nameservers from /etc/resolv.conf file.
|
||||
// A nameserver is usable if it's not one of current machine's IP addresses
|
||||
// and loopback IP addresses.
|
||||
func dnsFromResolvConf() []string {
|
||||
const (
|
||||
maxRetries = 10
|
||||
retryInterval = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
|
||||
var dns []string
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(retryInterval)
|
||||
}
|
||||
|
||||
nss := resolvconffile.NameServers()
|
||||
var localDNS []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, ns := range nss {
|
||||
if ip := net.ParseIP(ns); ip != nil {
|
||||
// skip loopback IPs
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
ipStr := v.String()
|
||||
if ip.String() == ipStr {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !seen[ip.String()] {
|
||||
seen[ip.String()] = true
|
||||
localDNS = append(localDNS, ip.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we successfully read the file and found nameservers, return them
|
||||
if len(localDNS) > 0 {
|
||||
return localDNS
|
||||
}
|
||||
}
|
||||
|
||||
return dns
|
||||
}
|
||||
|
||||
@@ -1,44 +1,484 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
const (
|
||||
maxDNSAdapterRetries = 5
|
||||
retryDelayDNSAdapter = 1 * time.Second
|
||||
defaultDNSAdapterTimeout = 10 * time.Second
|
||||
minDNSServers = 1 // Minimum number of DNS servers we want to find
|
||||
|
||||
DS_FORCE_REDISCOVERY = 0x00000001
|
||||
DS_DIRECTORY_SERVICE_REQUIRED = 0x00000010
|
||||
DS_BACKGROUND_ONLY = 0x00000100
|
||||
DS_IP_REQUIRED = 0x00000200
|
||||
DS_IS_DNS_NAME = 0x00020000
|
||||
DS_RETURN_DNS_NAME = 0x40000000
|
||||
)
|
||||
|
||||
type DomainControllerInfo struct {
|
||||
DomainControllerName *uint16
|
||||
DomainControllerAddress *uint16
|
||||
DomainControllerAddressType uint32
|
||||
DomainGuid windows.GUID
|
||||
DomainName *uint16
|
||||
DnsForestName *uint16
|
||||
Flags uint32
|
||||
DcSiteName *uint16
|
||||
ClientSiteName *uint16
|
||||
}
|
||||
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromAdapter}
|
||||
}
|
||||
|
||||
func dnsFromAdapter() []string {
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
ns := make([]string, 0, len(aas)*2)
|
||||
seen := make(map[string]bool)
|
||||
addressMap := make(map[string]struct{})
|
||||
for _, aa := range aas {
|
||||
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
|
||||
addressMap[a.Address.IP().String()] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, aa := range aas {
|
||||
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
||||
ip := dns.Address.IP()
|
||||
if ip == nil || ip.IsLoopback() || seen[ip.String()] {
|
||||
continue
|
||||
}
|
||||
if _, ok := addressMap[ip.String()]; ok {
|
||||
continue
|
||||
}
|
||||
seen[ip.String()] = true
|
||||
ns = append(ns, ip.String())
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDNSAdapterTimeout)
|
||||
defer cancel()
|
||||
|
||||
var ns []string
|
||||
var err error
|
||||
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
for i := 0; i < maxDNSAdapterRetries; i++ {
|
||||
if ctx.Err() != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"dnsFromAdapter lookup cancelled or timed out, attempt %d", i)
|
||||
return nil
|
||||
}
|
||||
|
||||
ns, err = getDNSServers(ctx)
|
||||
if err == nil && len(ns) >= minDNSServers {
|
||||
if i > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Successfully got DNS servers after %d attempts, found %d servers",
|
||||
i+1, len(ns))
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// if osResolver is not initialized, this is likely a command line run
|
||||
// and ctrld is already on the interface, abort retries
|
||||
if or == nil {
|
||||
return ns
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get DNS servers, attempt %d: %v", i+1, err)
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got insufficient DNS servers, retrying, found %d servers", len(ns))
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-time.After(retryDelayDNSAdapter):
|
||||
}
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get sufficient DNS servers after all attempts, max_retries=%d", maxDNSAdapterRetries)
|
||||
return ns
|
||||
}
|
||||
|
||||
func nameserversFromResolvconf() []string {
|
||||
func getDNSServers(ctx context.Context) ([]string, error) {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
// Check context before making the call
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// Get DNS servers from adapters (existing method)
|
||||
flags := winipcfg.GAAFlagIncludeGateways |
|
||||
winipcfg.GAAFlagIncludePrefix
|
||||
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting adapters: %w", err)
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Found network adapters, count=%d", len(aas))
|
||||
|
||||
// Try to get domain controller info if domain-joined
|
||||
var dcServers []string
|
||||
isDomain := checkDomainJoined()
|
||||
if isDomain {
|
||||
domainName, err := getLocalADDomain()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get local AD domain: %v", err)
|
||||
} else {
|
||||
// Load netapi32.dll
|
||||
netapi32 := windows.NewLazySystemDLL("netapi32.dll")
|
||||
dsDcName := netapi32.NewProc("DsGetDcNameW")
|
||||
|
||||
var info *DomainControllerInfo
|
||||
flags := uint32(DS_RETURN_DNS_NAME | DS_IP_REQUIRED | DS_IS_DNS_NAME)
|
||||
|
||||
domainUTF16, err := windows.UTF16PtrFromString(domainName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to convert domain name to UTF16: %v", err)
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Attempting to get DC for domain: %s with flags: 0x%x", domainName, flags)
|
||||
|
||||
// Call DsGetDcNameW with domain name
|
||||
ret, _, err := dsDcName.Call(
|
||||
0, // ComputerName - can be NULL
|
||||
uintptr(unsafe.Pointer(domainUTF16)), // DomainName
|
||||
0, // DomainGuid - not needed
|
||||
0, // SiteName - not needed
|
||||
uintptr(flags), // Flags
|
||||
uintptr(unsafe.Pointer(&info))) // DomainControllerInfo - output
|
||||
|
||||
if ret != 0 {
|
||||
switch ret {
|
||||
case 1355: // ERROR_NO_SUCH_DOMAIN
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Domain not found: %s (%d)", domainName, ret)
|
||||
case 1311: // ERROR_NO_LOGON_SERVERS
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"No logon servers available for domain: %s (%d)", domainName, ret)
|
||||
case 1004: // ERROR_DC_NOT_FOUND
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Domain controller not found for domain: %s (%d)", domainName, ret)
|
||||
case 1722: // RPC_S_SERVER_UNAVAILABLE
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"RPC server unavailable for domain: %s (%d)", domainName, ret)
|
||||
default:
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get domain controller info for domain %s: %d, %v", domainName, ret, err)
|
||||
}
|
||||
} else if info != nil {
|
||||
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(info)))
|
||||
|
||||
if info.DomainControllerAddress != nil {
|
||||
dcAddr := windows.UTF16PtrToString(info.DomainControllerAddress)
|
||||
dcAddr = strings.TrimPrefix(dcAddr, "\\\\")
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Found domain controller address: %s", dcAddr)
|
||||
|
||||
if ip := net.ParseIP(dcAddr); ip != nil {
|
||||
dcServers = append(dcServers, ip.String())
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added domain controller DNS servers: %v", dcServers)
|
||||
}
|
||||
} else {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"No domain controller address found")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Continue with existing adapter DNS collection
|
||||
ns := make([]string, 0, len(aas)*2)
|
||||
seen := make(map[string]bool)
|
||||
addressMap := make(map[string]struct{})
|
||||
|
||||
// Collect all local IPs
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping adapter %s - not up, status: %d", aa.FriendlyName(), aa.OperStatus)
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Processing adapter %s", aa.FriendlyName())
|
||||
|
||||
for a := aa.FirstUnicastAddress; a != nil; a = a.Next {
|
||||
ip := a.Address.IP().String()
|
||||
addressMap[ip] = struct{}{}
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added local IP %s from adapter %s", ip, aa.FriendlyName())
|
||||
}
|
||||
}
|
||||
|
||||
validInterfacesMap := validInterfaces()
|
||||
|
||||
// Collect DNS servers
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (software loopback)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
// if not in the validInterfacesMap, skip
|
||||
if _, ok := validInterfacesMap[aa.FriendlyName()]; !ok {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping %s (not in validInterfacesMap)", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
||||
ip := dns.Address.IP()
|
||||
if ip == nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping nil IP from adapter %s", aa.FriendlyName())
|
||||
continue
|
||||
}
|
||||
|
||||
ipStr := ip.String()
|
||||
l := logger.Debug().
|
||||
Str("ip", ipStr).
|
||||
Str("adapter", aa.FriendlyName())
|
||||
|
||||
if ip.IsLoopback() {
|
||||
l.Msg("Skipping loopback IP")
|
||||
continue
|
||||
}
|
||||
if seen[ipStr] {
|
||||
l.Msg("Skipping duplicate IP")
|
||||
continue
|
||||
}
|
||||
if _, ok := addressMap[ipStr]; ok {
|
||||
l.Msg("Skipping local interface IP")
|
||||
continue
|
||||
}
|
||||
|
||||
seen[ipStr] = true
|
||||
ns = append(ns, ipStr)
|
||||
l.Msg("Added DNS server")
|
||||
}
|
||||
}
|
||||
|
||||
// Add DC servers if they're not already in the list
|
||||
for _, dcServer := range dcServers {
|
||||
if !seen[dcServer] {
|
||||
seen[dcServer] = true
|
||||
ns = append(ns, dcServer)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added additional domain controller DNS server: %s", dcServer)
|
||||
}
|
||||
}
|
||||
|
||||
// if we have static DNS servers saved for the current default route, we should add them to the list
|
||||
drIfaceName, err := netmon.DefaultRouteInterface()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get default route interface: %v", err)
|
||||
} else {
|
||||
drIface, err := net.InterfaceByName(drIfaceName)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Failed to get interface by name %s: %v", drIfaceName, err)
|
||||
} else {
|
||||
staticNs, file := SavedStaticNameservers(drIface)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"static dns servers from %s: %v", file, staticNs)
|
||||
if len(staticNs) > 0 {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Adding static DNS servers from %s: %v", drIfaceName, staticNs)
|
||||
ns = append(ns, staticNs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(ns) == 0 {
|
||||
return nil, fmt.Errorf("no valid DNS servers found")
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"DNS server discovery completed, count=%d, servers=%v (including %d DC servers)",
|
||||
len(ns), ns, len(dcServers))
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// currentNameserversFromResolvconf returns a nil slice of strings.
|
||||
func currentNameserversFromResolvconf() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkDomainJoined checks if the machine is joined to an Active Directory domain
|
||||
// Returns whether it's domain joined and the domain name if available
|
||||
func checkDomainJoined() bool {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
var domain *uint16
|
||||
var status uint32
|
||||
|
||||
if err := windows.NetGetJoinInformation(nil, &domain, &status); err != nil {
|
||||
Log(context.Background(), logger.Debug(), "Failed to get domain join status: %v", err)
|
||||
return false
|
||||
}
|
||||
defer windows.NetApiBufferFree((*byte)(unsafe.Pointer(domain)))
|
||||
|
||||
// NETSETUP_JOIN_STATUS constants from Microsoft Windows API
|
||||
// See: https://learn.microsoft.com/en-us/windows/win32/api/lmjoin/ne-lmjoin-netsetup_join_status
|
||||
//
|
||||
// NetSetupUnknownStatus uint32 = 0 // The status is unknown
|
||||
// NetSetupUnjoined uint32 = 1 // The computer is not joined to a domain or workgroup
|
||||
// NetSetupWorkgroupName uint32 = 2 // The computer is joined to a workgroup
|
||||
// NetSetupDomainName uint32 = 3 // The computer is joined to a domain
|
||||
//
|
||||
// We only care about NetSetupDomainName.
|
||||
domainName := windows.UTF16PtrToString(domain)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Domain join status: domain=%s status=%d (UnknownStatus=0, Unjoined=1, WorkgroupName=2, DomainName=3)",
|
||||
domainName, status)
|
||||
|
||||
isDomain := status == syscall.NetSetupDomainName
|
||||
Log(context.Background(), logger.Debug(), "Is domain joined? status=%d, result=%v", status, isDomain)
|
||||
|
||||
return isDomain
|
||||
}
|
||||
|
||||
// getLocalADDomain uses Microsoft's WMI wrappers (github.com/microsoft/wmi/pkg/*)
|
||||
// to query the Domain field from Win32_ComputerSystem instead of a direct go-ole call.
|
||||
func getLocalADDomain() (string, error) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
// 1) Check environment variable
|
||||
envDomain := os.Getenv("USERDNSDOMAIN")
|
||||
if envDomain != "" {
|
||||
return strings.TrimSpace(envDomain), nil
|
||||
}
|
||||
|
||||
// 2) Query WMI via the microsoft/wmi library
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("Win32_ComputerSystem")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("WMI query failed: %v", err)
|
||||
}
|
||||
|
||||
// If no results, return an error
|
||||
if len(instances) == 0 {
|
||||
return "", fmt.Errorf("no rows returned from Win32_ComputerSystem")
|
||||
}
|
||||
|
||||
// We only care about the first row
|
||||
domainVal, err := instances[0].GetProperty("Domain")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("machine does not appear to have a domain set: %v", err)
|
||||
}
|
||||
|
||||
domainName := strings.TrimSpace(fmt.Sprintf("%v", domainVal))
|
||||
if domainName == "" {
|
||||
return "", fmt.Errorf("machine does not appear to have a domain set")
|
||||
}
|
||||
return domainName, nil
|
||||
}
|
||||
|
||||
// validInterfaces returns a list of all physical interfaces.
|
||||
// this is a duplicate of what is in net_windows.go, we should
|
||||
// clean this up so there is only one version
|
||||
func validInterfaces() map[string]struct{} {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
//load the logger
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get wmi network adapter: %v", err)
|
||||
return nil
|
||||
}
|
||||
var adapters []string
|
||||
for _, i := range instances {
|
||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get network adapter: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
name, err := adapter.GetPropertyName()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Warn(),
|
||||
"failed to get interface name: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||
//
|
||||
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
||||
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
||||
physical, err := adapter.GetPropertyConnectorPresent()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"failed to get network adapter connector present property: %v", err)
|
||||
continue
|
||||
}
|
||||
if !physical {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"skipping non-physical adapter: %s", name)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's a hardware interface. Checking only for connector present is not enough
|
||||
// because some interfaces are not physical but have a connector.
|
||||
hardware, err := adapter.GetPropertyHardwareInterface()
|
||||
if err != nil {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"failed to get network adapter hardware interface property: %v", err)
|
||||
continue
|
||||
}
|
||||
if !hardware {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"skipping non-hardware interface: %s", name)
|
||||
continue
|
||||
}
|
||||
|
||||
adapters = append(adapters, name)
|
||||
}
|
||||
|
||||
m := make(map[string]struct{})
|
||||
for _, ifaceName := range adapters {
|
||||
m[ifaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
48
net.go
48
net.go
@@ -6,6 +6,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
@@ -14,36 +16,38 @@ var (
|
||||
ipv6Available atomic.Bool
|
||||
)
|
||||
|
||||
const ipv6ProbingInterval = 10 * time.Second
|
||||
|
||||
func hasIPv6() bool {
|
||||
// HasIPv6 reports whether the current network stack has IPv6 available.
|
||||
func HasIPv6() bool {
|
||||
hasIPv6Once.Do(func() {
|
||||
ProxyLogger.Load().Debug().Msg("checking for IPv6 availability once")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
val := ctrldnet.IPv6Available(ctx)
|
||||
ipv6Available.Store(val)
|
||||
go probingIPv6(context.TODO(), val)
|
||||
ProxyLogger.Load().Debug().Msgf("ipv6 availability: %v", val)
|
||||
mon, err := netmon.New(func(format string, args ...any) {})
|
||||
if err != nil {
|
||||
ProxyLogger.Load().Debug().Err(err).Msg("failed to monitor IPv6 state")
|
||||
return
|
||||
}
|
||||
mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) {
|
||||
old := ipv6Available.Load()
|
||||
cur := delta.Monitor.InterfaceState().HaveV6
|
||||
if old != cur {
|
||||
ProxyLogger.Load().Warn().Msgf("ipv6 availability changed, old: %v, new: %v", old, cur)
|
||||
} else {
|
||||
ProxyLogger.Load().Debug().Msg("ipv6 availability does not changed")
|
||||
}
|
||||
ipv6Available.Store(cur)
|
||||
})
|
||||
mon.Start()
|
||||
})
|
||||
return ipv6Available.Load()
|
||||
}
|
||||
|
||||
// TODO(cuonglm): doing poll check natively for supported platforms.
|
||||
func probingIPv6(ctx context.Context, old bool) {
|
||||
ticker := time.NewTicker(ipv6ProbingInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
cur := ctrldnet.IPv6Available(ctx)
|
||||
if ipv6Available.CompareAndSwap(old, cur) {
|
||||
old = cur
|
||||
}
|
||||
}()
|
||||
}
|
||||
// DisableIPv6 marks IPv6 as unavailable if enabled.
|
||||
func DisableIPv6() {
|
||||
if ipv6Available.CompareAndSwap(true, false) {
|
||||
ProxyLogger.Load().Debug().Msg("turned off IPv6 availability")
|
||||
}
|
||||
}
|
||||
|
||||
35
net_darwin.go
Normal file
35
net_darwin.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// validInterfaces returns a set of all valid hardware ports.
|
||||
// TODO: deduplicated with cmd/cli/net_darwin.go in v2.
|
||||
func validInterfaces() map[string]struct{} {
|
||||
b, err := exec.Command("networksetup", "-listallhardwareports").Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return parseListAllHardwarePorts(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports"
|
||||
// and returns map presents all hardware ports.
|
||||
func parseListAllHardwarePorts(r io.Reader) map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
after, ok := strings.CutPrefix(line, "Device: ")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
m[after] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package cli
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"maps"
|
||||
15
net_others.go
Normal file
15
net_others.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !darwin && !windows && !linux
|
||||
|
||||
package ctrld
|
||||
|
||||
import "tailscale.com/net/netmon"
|
||||
|
||||
// validInterfaces returns a set containing only default route interfaces.
|
||||
// TODO: deuplicated with cmd/cli/net_others.go in v2.
|
||||
func validInterfaces() map[string]struct{} {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]struct{}{defaultRoute.InterfaceName: {}}
|
||||
}
|
||||
497
resolver.go
497
resolver.go
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -13,6 +15,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsaddr"
|
||||
)
|
||||
@@ -30,22 +34,51 @@ const (
|
||||
ResolverTypeOS = "os"
|
||||
// ResolverTypeLegacy specifies legacy resolver.
|
||||
ResolverTypeLegacy = "legacy"
|
||||
// ResolverTypePrivate is like ResolverTypeOS, but use for local resolver only.
|
||||
// ResolverTypePrivate is like ResolverTypeOS, but use for private resolver only.
|
||||
ResolverTypePrivate = "private"
|
||||
// ResolverTypeLocal is like ResolverTypeOS, but use for local resolver only.
|
||||
ResolverTypeLocal = "local"
|
||||
// ResolverTypeSDNS specifies resolver with information encoded using DNS Stamps.
|
||||
// See: https://dnscrypt.info/stamps-specifications/
|
||||
ResolverTypeSDNS = "sdns"
|
||||
)
|
||||
|
||||
const (
|
||||
controldBootstrapDns = "76.76.2.22"
|
||||
controldPublicDns = "76.76.2.0"
|
||||
)
|
||||
const controldPublicDns = "76.76.2.0"
|
||||
|
||||
var controldPublicDnsWithPort = net.JoinHostPort(controldPublicDns, "53")
|
||||
|
||||
// or is the Resolver used for ResolverTypeOS.
|
||||
var or = newResolverWithNameserver(defaultNameservers())
|
||||
var localResolver Resolver
|
||||
|
||||
func init() {
|
||||
// Initializing ProxyLogger here, so other places don't have to do nil check.
|
||||
l := zerolog.New(io.Discard)
|
||||
ProxyLogger.Store(&l)
|
||||
|
||||
localResolver = newLocalResolver()
|
||||
}
|
||||
|
||||
var (
|
||||
resolverMutex sync.Mutex
|
||||
or *osResolver
|
||||
defaultLocalIPv4 atomic.Value // holds net.IP (IPv4)
|
||||
defaultLocalIPv6 atomic.Value // holds net.IP (IPv6)
|
||||
)
|
||||
|
||||
func newLocalResolver() Resolver {
|
||||
var nss []string
|
||||
for _, addr := range Rfc1918Addresses() {
|
||||
nss = append(nss, net.JoinHostPort(addr, "53"))
|
||||
}
|
||||
return NewResolverWithNameserver(nss)
|
||||
}
|
||||
|
||||
// LanQueryCtxKey is the context.Context key to indicate that the request is for LAN network.
|
||||
type LanQueryCtxKey struct{}
|
||||
|
||||
// LanQueryCtx returns a context.Context with LanQueryCtxKey set.
|
||||
func LanQueryCtx(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, LanQueryCtxKey{}, true)
|
||||
}
|
||||
|
||||
// defaultNameservers is like nameservers with each element formed "ip:53".
|
||||
func defaultNameservers() []string {
|
||||
@@ -63,17 +96,37 @@ func availableNameservers() []string {
|
||||
// Ignore local addresses to prevent loop.
|
||||
regularIPs, loopbackIPs, _ := netmon.LocalAddresses()
|
||||
machineIPsMap := make(map[string]struct{}, len(regularIPs))
|
||||
|
||||
//load the logger
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got local addresses - regular IPs: %v, loopback IPs: %v", regularIPs, loopbackIPs)
|
||||
|
||||
for _, v := range slices.Concat(regularIPs, loopbackIPs) {
|
||||
machineIPsMap[v.String()] = struct{}{}
|
||||
ipStr := v.String()
|
||||
machineIPsMap[ipStr] = struct{}{}
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added local IP to OS resolverexclusion map: %s", ipStr)
|
||||
}
|
||||
for _, ns := range nameservers() {
|
||||
|
||||
systemNameservers := nameservers()
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Got system nameservers: %v", systemNameservers)
|
||||
|
||||
for _, ns := range systemNameservers {
|
||||
if _, ok := machineIPsMap[ns]; ok {
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Skipping local nameserver: %s", ns)
|
||||
continue
|
||||
}
|
||||
if testNameserver(ns) {
|
||||
nss = append(nss, ns)
|
||||
}
|
||||
nss = append(nss, ns)
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Added non-local nameserver: %s", ns)
|
||||
}
|
||||
|
||||
Log(context.Background(), logger.Debug(),
|
||||
"Final available nameservers: %v", nss)
|
||||
return nss
|
||||
}
|
||||
|
||||
@@ -82,77 +135,47 @@ func availableNameservers() []string {
|
||||
//
|
||||
// It's the caller's responsibility to ensure the system DNS is in a clean state before
|
||||
// calling this function.
|
||||
func InitializeOsResolver() []string {
|
||||
return initializeOsResolver(availableNameservers())
|
||||
func InitializeOsResolver(guardAgainstNoNameservers bool) []string {
|
||||
nameservers := availableNameservers()
|
||||
// if no nameservers, return empty slice so we dont remove all nameservers
|
||||
if len(nameservers) == 0 && guardAgainstNoNameservers {
|
||||
return []string{}
|
||||
}
|
||||
ns := initializeOsResolver(nameservers)
|
||||
resolverMutex.Lock()
|
||||
defer resolverMutex.Unlock()
|
||||
or = newResolverWithNameserver(ns)
|
||||
return ns
|
||||
}
|
||||
|
||||
// initializeOsResolver performs logic for choosing OS resolver nameserver.
|
||||
// The logic:
|
||||
//
|
||||
// - First available LAN servers are saved and store.
|
||||
// - Later calls, if no LAN servers available, the saved servers above will be used.
|
||||
func initializeOsResolver(servers []string) []string {
|
||||
var (
|
||||
nss []string
|
||||
publicNss []string
|
||||
)
|
||||
var (
|
||||
lastLanServer netip.Addr
|
||||
curLanServer netip.Addr
|
||||
curLanServerAvailable bool
|
||||
)
|
||||
if p := or.currentLanServer.Load(); p != nil {
|
||||
curLanServer = *p
|
||||
or.currentLanServer.Store(nil)
|
||||
}
|
||||
if p := or.lastLanServer.Load(); p != nil {
|
||||
lastLanServer = *p
|
||||
or.lastLanServer.Store(nil)
|
||||
}
|
||||
|
||||
var lanNss, publicNss []string
|
||||
|
||||
// First categorize servers
|
||||
for _, ns := range servers {
|
||||
addr, err := netip.ParseAddr(ns)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
server := net.JoinHostPort(ns, "53")
|
||||
// Always use new public nameserver.
|
||||
if !isLanAddr(addr) {
|
||||
publicNss = append(publicNss, server)
|
||||
nss = append(nss, server)
|
||||
continue
|
||||
}
|
||||
// For LAN server, storing only current and last LAN server if any.
|
||||
if addr.Compare(curLanServer) == 0 {
|
||||
curLanServerAvailable = true
|
||||
if isLanAddr(addr) {
|
||||
lanNss = append(lanNss, server)
|
||||
} else {
|
||||
if addr.Compare(lastLanServer) == 0 {
|
||||
or.lastLanServer.Store(&addr)
|
||||
} else {
|
||||
if or.currentLanServer.CompareAndSwap(nil, &addr) {
|
||||
nss = append(nss, server)
|
||||
}
|
||||
}
|
||||
publicNss = append(publicNss, server)
|
||||
}
|
||||
}
|
||||
// Store current LAN server as last one only if it's still available.
|
||||
if curLanServerAvailable && curLanServer.IsValid() {
|
||||
or.lastLanServer.Store(&curLanServer)
|
||||
nss = append(nss, net.JoinHostPort(curLanServer.String(), "53"))
|
||||
}
|
||||
if len(publicNss) == 0 {
|
||||
publicNss = append(publicNss, controldPublicDnsWithPort)
|
||||
nss = append(nss, controldPublicDnsWithPort)
|
||||
}
|
||||
or.publicServer.Store(&publicNss)
|
||||
return nss
|
||||
}
|
||||
|
||||
// testPlainDnsNameserver sends a test query to DNS nameserver to check if the server is available.
|
||||
func testNameserver(addr string) bool {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("controld.com.", dns.TypeNS)
|
||||
client := new(dns.Client)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
_, _, err := client.ExchangeContext(ctx, msg, net.JoinHostPort(addr, "53"))
|
||||
if err != nil {
|
||||
ProxyLogger.Load().Debug().Err(err).Msgf("failed to connect to OS nameserver: %s", addr)
|
||||
if len(publicNss) == 0 {
|
||||
publicNss = []string{controldPublicDnsWithPort}
|
||||
}
|
||||
return err == nil
|
||||
|
||||
return slices.Concat(lanNss, publicNss)
|
||||
}
|
||||
|
||||
// Resolver is the interface that wraps the basic DNS operations.
|
||||
@@ -175,19 +198,28 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
case ResolverTypeDOQ:
|
||||
return &doqResolver{uc: uc}, nil
|
||||
case ResolverTypeOS:
|
||||
resolverMutex.Lock()
|
||||
if or == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver")
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
resolverMutex.Unlock()
|
||||
return or, nil
|
||||
case ResolverTypeLegacy:
|
||||
return &legacyResolver{uc: uc}, nil
|
||||
case ResolverTypePrivate:
|
||||
return NewPrivateResolver(), nil
|
||||
case ResolverTypeLocal:
|
||||
return localResolver, nil
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
|
||||
}
|
||||
|
||||
type osResolver struct {
|
||||
currentLanServer atomic.Pointer[netip.Addr]
|
||||
lastLanServer atomic.Pointer[netip.Addr]
|
||||
publicServer atomic.Pointer[[]string]
|
||||
lanServers atomic.Pointer[[]string]
|
||||
publicServers atomic.Pointer[[]string]
|
||||
group *singleflight.Group
|
||||
cache *sync.Map
|
||||
}
|
||||
|
||||
type osResolverResult struct {
|
||||
@@ -197,26 +229,155 @@ type osResolverResult struct {
|
||||
lan bool
|
||||
}
|
||||
|
||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||
// Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
publicServers := *o.publicServer.Load()
|
||||
nss := make([]string, 0, 2)
|
||||
if p := o.currentLanServer.Load(); p != nil {
|
||||
nss = append(nss, net.JoinHostPort(p.String(), "53"))
|
||||
type publicResponse struct {
|
||||
answer *dns.Msg
|
||||
server string
|
||||
}
|
||||
|
||||
// SetDefaultLocalIPv4 updates the stored local IPv4.
|
||||
func SetDefaultLocalIPv4(ip net.IP) {
|
||||
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv4: %s", ip)
|
||||
defaultLocalIPv4.Store(ip)
|
||||
}
|
||||
|
||||
// SetDefaultLocalIPv6 updates the stored local IPv6.
|
||||
func SetDefaultLocalIPv6(ip net.IP) {
|
||||
Log(context.Background(), ProxyLogger.Load().Debug(), "SetDefaultLocalIPv6: %s", ip)
|
||||
defaultLocalIPv6.Store(ip)
|
||||
}
|
||||
|
||||
// GetDefaultLocalIPv4 returns the stored local IPv4 or nil if none.
|
||||
func GetDefaultLocalIPv4() net.IP {
|
||||
if v := defaultLocalIPv4.Load(); v != nil {
|
||||
return v.(net.IP)
|
||||
}
|
||||
if p := o.lastLanServer.Load(); p != nil {
|
||||
nss = append(nss, net.JoinHostPort(p.String(), "53"))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDefaultLocalIPv6 returns the stored local IPv6 or nil if none.
|
||||
func GetDefaultLocalIPv6() net.IP {
|
||||
if v := defaultLocalIPv6.Load(); v != nil {
|
||||
return v.(net.IP)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// customDNSExchange wraps the DNS exchange to use our debug dialer.
|
||||
// It uses dns.ExchangeWithConn so that our custom dialer is used directly.
|
||||
func customDNSExchange(ctx context.Context, msg *dns.Msg, server string, desiredLocalIP net.IP) (*dns.Msg, time.Duration, error) {
|
||||
baseDialer := &net.Dialer{
|
||||
Timeout: 3 * time.Second,
|
||||
Resolver: &net.Resolver{PreferGo: true},
|
||||
}
|
||||
if desiredLocalIP != nil {
|
||||
baseDialer.LocalAddr = &net.UDPAddr{IP: desiredLocalIP, Port: 0}
|
||||
}
|
||||
dnsClient := &dns.Client{Net: "udp"}
|
||||
dnsClient.Dialer = baseDialer
|
||||
return dnsClient.ExchangeContext(ctx, msg, server)
|
||||
}
|
||||
|
||||
const hotCacheTTL = time.Second
|
||||
|
||||
// Resolve resolves DNS queries using pre-configured nameservers.
|
||||
// The Query is sent to all nameservers concurrently, and the first
|
||||
// success response will be returned.
|
||||
//
|
||||
// To guard against unexpected DoS to upstreams, multiple queries of
|
||||
// the same Qtype to a domain will be shared, so there's only 1 qps
|
||||
// for each upstream at any time.
|
||||
//
|
||||
// Further, a hot cache will be used, so repeated queries will be cached
|
||||
// for a short period (currently 1 second), reducing unnecessary traffics
|
||||
// sent to upstreams.
|
||||
func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
if len(msg.Question) == 0 {
|
||||
return nil, errors.New("no question found")
|
||||
}
|
||||
domain := strings.TrimSuffix(msg.Question[0].Name, ".")
|
||||
qtype := msg.Question[0].Qtype
|
||||
|
||||
// Unique key for the singleflight group.
|
||||
key := fmt.Sprintf("%s:%d:", domain, qtype)
|
||||
|
||||
// Checking the cache first.
|
||||
if val, ok := o.cache.Load(key); ok {
|
||||
if val, ok := val.(*dns.Msg); ok {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "hit hot cached result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
res := val.Copy()
|
||||
SetCacheReply(res, msg, val.Rcode)
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure only one DNS query is in flight for the key.
|
||||
v, err, shared := o.group.Do(key, func() (interface{}, error) {
|
||||
msg, err := o.resolve(ctx, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// If we got an answer, storing it to the hot cache for hotCacheTTL
|
||||
// This prevents possible DoS to upstream, ensuring there's only 1 QPS.
|
||||
o.cache.Store(key, msg)
|
||||
// Depends on go runtime scheduling, the result may end up in hot cache longer
|
||||
// than hotCacheTTL duration. However, this is fine since we only want to guard
|
||||
// against DoS attack. The result will be cleaned from the cache eventually.
|
||||
time.AfterFunc(hotCacheTTL, func() {
|
||||
o.removeCache(key)
|
||||
})
|
||||
return msg, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sharedMsg, ok := v.(*dns.Msg)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid answer for key: %s", key)
|
||||
}
|
||||
res := sharedMsg.Copy()
|
||||
SetCacheReply(res, msg, sharedMsg.Rcode)
|
||||
if shared {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "shared result: %s - %s", domain, dns.TypeToString[qtype])
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// resolve sends the query to current nameservers.
|
||||
func (o *osResolver) resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
publicServers := *o.publicServers.Load()
|
||||
var nss []string
|
||||
if p := o.lanServers.Load(); p != nil {
|
||||
nss = append(nss, (*p)...)
|
||||
}
|
||||
numServers := len(nss) + len(publicServers)
|
||||
if numServers == 0 {
|
||||
return nil, errors.New("no nameservers available")
|
||||
|
||||
// If this is a LAN query, skip public DNS.
|
||||
lan, ok := ctx.Value(LanQueryCtxKey{}).(bool)
|
||||
|
||||
// remove controldPublicDnsWithPort from publicServers for LAN queries
|
||||
// this is to prevent DoS for high frequency local requests
|
||||
if ok && lan {
|
||||
if index := slices.Index(publicServers, controldPublicDnsWithPort); index != -1 {
|
||||
publicServers = slices.Delete(publicServers, index, index+1)
|
||||
numServers--
|
||||
}
|
||||
}
|
||||
question := ""
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
question = msg.Question[0].Name
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "os resolver query for %s with nameservers: %v public: %v", question, nss, publicServers)
|
||||
|
||||
// New check: If no resolvers are available, return an error.
|
||||
if numServers == 0 {
|
||||
return nil, errors.New("no nameservers available for query")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
dnsClient := &dns.Client{Net: "udp"}
|
||||
ch := make(chan *osResolverResult, numServers)
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(numServers)
|
||||
@@ -229,70 +390,127 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
for _, server := range servers {
|
||||
go func(server string) {
|
||||
defer wg.Done()
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
||||
var answer *dns.Msg
|
||||
var err error
|
||||
var localOSResolverIP net.IP
|
||||
if runtime.GOOS == "darwin" {
|
||||
host, _, err := net.SplitHostPort(server)
|
||||
if err == nil {
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil && ip.To4() == nil {
|
||||
// IPv6 nameserver; use default IPv6 address (if set)
|
||||
localOSResolverIP = GetDefaultLocalIPv6()
|
||||
} else {
|
||||
localOSResolverIP = GetDefaultLocalIPv4()
|
||||
}
|
||||
}
|
||||
}
|
||||
answer, _, err = customDNSExchange(ctx, msg.Copy(), server, localOSResolverIP)
|
||||
ch <- &osResolverResult{answer: answer, err: err, server: server, lan: isLan}
|
||||
}(server)
|
||||
}
|
||||
}
|
||||
do(nss, true)
|
||||
do(publicServers, false)
|
||||
|
||||
logAnswer := func(server string) {
|
||||
if before, _, found := strings.Cut(server, ":"); found {
|
||||
server = before
|
||||
host, _, err := net.SplitHostPort(server)
|
||||
if err != nil {
|
||||
// If splitting fails, fallback to the original server string
|
||||
host = server
|
||||
}
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", server)
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got answer from nameserver: %s", host)
|
||||
}
|
||||
|
||||
// try local nameservers
|
||||
if len(nss) > 0 {
|
||||
do(nss, true)
|
||||
}
|
||||
|
||||
// we must always try the public servers too, since DCHP may have only public servers
|
||||
// this is okay to do since we always prefer LAN nameserver responses
|
||||
if len(publicServers) > 0 {
|
||||
do(publicServers, false)
|
||||
}
|
||||
|
||||
var (
|
||||
nonSuccessAnswer *dns.Msg
|
||||
nonSuccessServer string
|
||||
controldSuccessAnswer *dns.Msg
|
||||
publicServerAnswer *dns.Msg
|
||||
publicServer string
|
||||
publicResponses []publicResponse
|
||||
)
|
||||
errs := make([]error, 0, numServers)
|
||||
for res := range ch {
|
||||
switch {
|
||||
case res.answer != nil && res.answer.Rcode == dns.RcodeSuccess:
|
||||
switch {
|
||||
case res.lan:
|
||||
// Always prefer LAN responses immediately
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using LAN answer from: %s", res.server)
|
||||
cancel()
|
||||
logAnswer(res.server)
|
||||
return res.answer, nil
|
||||
case res.server == controldPublicDnsWithPort:
|
||||
controldSuccessAnswer = res.answer // only use ControlD answer as last one.
|
||||
case !res.lan && publicServerAnswer == nil:
|
||||
publicServerAnswer = res.answer // use public DNS answer after LAN server..
|
||||
publicServer = res.server
|
||||
default:
|
||||
controldSuccessAnswer = res.answer
|
||||
case !res.lan:
|
||||
// if there are no LAN nameservers, we should not wait
|
||||
// just use the first response
|
||||
if len(nss) == 0 {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", res.server)
|
||||
cancel()
|
||||
logAnswer(res.server)
|
||||
return res.answer, nil
|
||||
}
|
||||
publicResponses = append(publicResponses, publicResponse{
|
||||
answer: res.answer,
|
||||
server: res.server,
|
||||
})
|
||||
}
|
||||
case res.answer != nil:
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "got non-success answer from: %s with code: %d",
|
||||
res.server, res.answer.Rcode)
|
||||
// When there are no LAN nameservers, we should not wait
|
||||
// for other nameservers to respond.
|
||||
if len(nss) == 0 {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "no lan nameservers using public non success answer")
|
||||
cancel()
|
||||
logAnswer(res.server)
|
||||
return res.answer, nil
|
||||
}
|
||||
case res.answer != nil:
|
||||
nonSuccessAnswer = res.answer
|
||||
nonSuccessServer = res.server
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
if publicServerAnswer != nil {
|
||||
logAnswer(publicServer)
|
||||
return publicServerAnswer, nil
|
||||
|
||||
if len(publicResponses) > 0 {
|
||||
resp := publicResponses[0]
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using public answer from: %s", resp.server)
|
||||
logAnswer(resp.server)
|
||||
return resp.answer, nil
|
||||
}
|
||||
if controldSuccessAnswer != nil {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using ControlD answer from: %s", controldPublicDnsWithPort)
|
||||
logAnswer(controldPublicDnsWithPort)
|
||||
return controldSuccessAnswer, nil
|
||||
}
|
||||
if nonSuccessAnswer != nil {
|
||||
Log(ctx, ProxyLogger.Load().Debug(), "using non-success answer from: %s", nonSuccessServer)
|
||||
logAnswer(nonSuccessServer)
|
||||
return nonSuccessAnswer, nil
|
||||
}
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (o *osResolver) removeCache(key string) {
|
||||
o.cache.Delete(key)
|
||||
}
|
||||
|
||||
type legacyResolver struct {
|
||||
uc *UpstreamConfig
|
||||
}
|
||||
|
||||
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
// See comment in (*dotResolver).resolve method.
|
||||
dialer := newDialer(net.JoinHostPort(controldBootstrapDns, "53"))
|
||||
dialer := newDialer(net.JoinHostPort(controldPublicDns, "53"))
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
@@ -321,19 +539,41 @@ func (d dummyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, err
|
||||
return ans, nil
|
||||
}
|
||||
|
||||
// LookupIP looks up host using OS resolver.
|
||||
// LookupIP looks up domain using current system nameservers settings.
|
||||
// It returns a slice of that host's IPv4 and IPv6 addresses.
|
||||
func LookupIP(domain string) []string {
|
||||
return lookupIP(domain, -1, true)
|
||||
nss := initDefaultOsResolver()
|
||||
return lookupIP(domain, -1, nss)
|
||||
}
|
||||
|
||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||
nss := defaultNameservers()
|
||||
if withBootstrapDNS {
|
||||
nss = append([]string{net.JoinHostPort(controldBootstrapDns, "53")}, nss...)
|
||||
// initDefaultOsResolver initializes the default OS resolver with system's default nameservers if it hasn't been initialized yet.
|
||||
// It returns the combined list of LAN and public nameservers currently held by the resolver.
|
||||
func initDefaultOsResolver() []string {
|
||||
resolverMutex.Lock()
|
||||
defer resolverMutex.Unlock()
|
||||
if or == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("Initialize new OS resolver with default nameservers")
|
||||
or = newResolverWithNameserver(defaultNameservers())
|
||||
}
|
||||
resolver := newResolverWithNameserver(nss)
|
||||
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, nss)
|
||||
nss := *or.lanServers.Load()
|
||||
nss = append(nss, *or.publicServers.Load()...)
|
||||
return nss
|
||||
}
|
||||
|
||||
// lookupIP looks up domain with given timeout and bootstrapDNS.
|
||||
// If the timeout is negative, default timeout 2000 ms will be used.
|
||||
// It returns nil if bootstrapDNS is nil or empty.
|
||||
func lookupIP(domain string, timeout int, bootstrapDNS []string) (ips []string) {
|
||||
if net.ParseIP(domain) != nil {
|
||||
return []string{domain}
|
||||
}
|
||||
if bootstrapDNS == nil {
|
||||
ProxyLogger.Load().Debug().Msgf("empty bootstrap DNS")
|
||||
return nil
|
||||
}
|
||||
|
||||
resolver := newResolverWithNameserver(bootstrapDNS)
|
||||
ProxyLogger.Load().Debug().Msgf("resolving %q using bootstrap DNS %q", domain, bootstrapDNS)
|
||||
timeoutMs := 2000
|
||||
if timeout > 0 && timeout < timeoutMs {
|
||||
timeoutMs = timeout
|
||||
@@ -406,6 +646,9 @@ func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string)
|
||||
// - Gateway IP address (depends on OS).
|
||||
// - Input servers.
|
||||
func NewBootstrapResolver(servers ...string) Resolver {
|
||||
logger := *ProxyLogger.Load()
|
||||
|
||||
Log(context.Background(), logger.Debug(), "NewBootstrapResolver called with servers: %v", servers)
|
||||
nss := defaultNameservers()
|
||||
nss = append([]string{controldPublicDnsWithPort}, nss...)
|
||||
for _, ns := range servers {
|
||||
@@ -422,8 +665,8 @@ func NewBootstrapResolver(servers ...string) Resolver {
|
||||
//
|
||||
// This is useful for doing PTR lookup in LAN network.
|
||||
func NewPrivateResolver() Resolver {
|
||||
nss := defaultNameservers()
|
||||
resolveConfNss := nameserversFromResolvconf()
|
||||
nss := initDefaultOsResolver()
|
||||
resolveConfNss := currentNameserversFromResolvconf()
|
||||
localRfc1918Addrs := Rfc1918Addresses()
|
||||
n := 0
|
||||
for _, ns := range nss {
|
||||
@@ -466,25 +709,35 @@ func NewResolverWithNameserver(nameservers []string) Resolver {
|
||||
// newResolverWithNameserver returns an OS resolver from given nameservers list.
|
||||
// The caller must ensure each server in list is formed "ip:53".
|
||||
func newResolverWithNameserver(nameservers []string) *osResolver {
|
||||
r := &osResolver{}
|
||||
nss := slices.Sorted(slices.Values(nameservers))
|
||||
for i, ns := range nss {
|
||||
r := &osResolver{
|
||||
group: &singleflight.Group{},
|
||||
cache: &sync.Map{},
|
||||
}
|
||||
var publicNss []string
|
||||
var lanNss []string
|
||||
for _, ns := range slices.Sorted(slices.Values(nameservers)) {
|
||||
ip, _, _ := net.SplitHostPort(ns)
|
||||
addr, _ := netip.ParseAddr(ip)
|
||||
if isLanAddr(addr) {
|
||||
r.currentLanServer.Store(&addr)
|
||||
nss = slices.Delete(nss, i, i+1)
|
||||
break
|
||||
lanNss = append(lanNss, ns)
|
||||
} else {
|
||||
publicNss = append(publicNss, ns)
|
||||
}
|
||||
}
|
||||
r.publicServer.Store(&nss)
|
||||
r.lanServers.Store(&lanNss)
|
||||
r.publicServers.Store(&publicNss)
|
||||
return r
|
||||
}
|
||||
|
||||
// Rfc1918Addresses returns the list of local interfaces private IP addresses
|
||||
// Rfc1918Addresses returns the list of local physical interfaces private IP addresses
|
||||
func Rfc1918Addresses() []string {
|
||||
vis := validInterfaces()
|
||||
var res []string
|
||||
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
|
||||
// Skip virtual interfaces.
|
||||
if _, existed := vis[i.Name]; !existed {
|
||||
return
|
||||
}
|
||||
addrs, _ := i.Addrs()
|
||||
for _, addr := range addrs {
|
||||
ipNet, ok := addr.(*net.IPNet)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user