mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Compare commits
104 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5553490b27 | ||
|
|
eaf39f48a0 | ||
|
|
a5ddbdcb42 | ||
|
|
0c99d27be5 | ||
|
|
b9eb89c02e | ||
|
|
53f8d006f0 | ||
|
|
929de49c7b | ||
|
|
542c4f7daf | ||
|
|
c941f9c621 | ||
|
|
25eae187db | ||
|
|
726a25a7ea | ||
|
|
a46bb152af | ||
|
|
bbfa7c6c22 | ||
|
|
1cd54a48e9 | ||
|
|
2d950eecdf | ||
|
|
b143e46eb0 | ||
|
|
8fda856e24 | ||
|
|
54e63ccf9b | ||
|
|
ee53db1e35 | ||
|
|
fc502b920b | ||
|
|
20eae82f11 | ||
|
|
d2fc530316 | ||
|
|
7ac5555a84 | ||
|
|
15d397d8a6 | ||
|
|
b471adfb09 | ||
|
|
d7a38363e6 | ||
|
|
90def8f9b5 | ||
|
|
b126db453b | ||
|
|
601d357456 | ||
|
|
3a2024ebd7 | ||
|
|
6cd451acec | ||
|
|
3b6c12abd4 | ||
|
|
d9dfc584e7 | ||
|
|
57fa68970a | ||
|
|
fa14f1dadf | ||
|
|
9689607409 | ||
|
|
d75f871541 | ||
|
|
45895067c6 | ||
|
|
521f06dcc1 | ||
|
|
5b6a3a4c6f | ||
|
|
be497a68de | ||
|
|
c872a3b3f6 | ||
|
|
e0ae0f8e7b | ||
|
|
ad4ca32873 | ||
|
|
24100c4cbe | ||
|
|
e3a792d50d | ||
|
|
440d085c6d | ||
|
|
270ea9f6ca | ||
|
|
7a156d7d15 | ||
|
|
4c45e6cf3d | ||
|
|
704bc27dba | ||
|
|
b267572b38 | ||
|
|
5cad0d6be1 | ||
|
|
56d8dc865f | ||
|
|
d57c1d6d44 | ||
|
|
02fa7fbe2e | ||
|
|
07689954bf | ||
|
|
a7ea20b117 | ||
|
|
43fecdf60f | ||
|
|
31239684c7 | ||
|
|
5528ac8bf1 | ||
|
|
411e23ecfe | ||
|
|
7bf231643b | ||
|
|
2326160f2f | ||
|
|
68fe7e8406 | ||
|
|
c7bad63869 | ||
|
|
69319c6b41 | ||
|
|
9df381d3d1 | ||
|
|
0af7f64bca | ||
|
|
f73cbde7a5 | ||
|
|
0645a738ad | ||
|
|
d52cd11322 | ||
|
|
d3d08022cc | ||
|
|
21c8b9f8e7 | ||
|
|
6c55d8f139 | ||
|
|
ccdb2a3f70 | ||
|
|
f5ef9b917e | ||
|
|
a5443d5ca4 | ||
|
|
2c7d95bba2 | ||
|
|
8a2cdbfaa3 | ||
|
|
c94be0df35 | ||
|
|
4b6a976747 | ||
|
|
0043fdf859 | ||
|
|
24e62e18fa | ||
|
|
663dbbb476 | ||
|
|
471427a439 | ||
|
|
a777c4b00f | ||
|
|
dcc4cdd316 | ||
|
|
9c22701940 | ||
|
|
a77a924320 | ||
|
|
95dbf71939 | ||
|
|
8869e33a20 | ||
|
|
c94e1b02d2 | ||
|
|
42d29b626b | ||
|
|
b65a5ac283 | ||
|
|
ba48ff5965 | ||
|
|
b3a342bc44 | ||
|
|
9927803497 | ||
|
|
f0c604a9f1 | ||
|
|
8a56389396 | ||
|
|
9f7bfc76db | ||
|
|
a7a5501ea5 | ||
|
|
c401c4ef87 | ||
|
|
8ffb42962a |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,3 +1,5 @@
|
||||
|
||||
dist/
|
||||
gon.hcl
|
||||
|
||||
/Build
|
||||
.DS_Store
|
||||
|
||||
@@ -9,6 +9,8 @@ builds:
|
||||
- -trimpath
|
||||
ldflags:
|
||||
- -s -w
|
||||
- -X main.version={{.Version}}
|
||||
- -X main.commit={{.Commit}}
|
||||
goos:
|
||||
- darwin
|
||||
goarch:
|
||||
|
||||
@@ -9,11 +9,15 @@ builds:
|
||||
- -trimpath
|
||||
ldflags:
|
||||
- -s -w
|
||||
- -X main.version={{.Version}}
|
||||
- -X main.commit={{.Commit}}
|
||||
goos:
|
||||
- darwin
|
||||
- linux
|
||||
- windows
|
||||
goarch:
|
||||
- 386
|
||||
- arm
|
||||
- amd64
|
||||
- arm64
|
||||
tags:
|
||||
|
||||
@@ -9,6 +9,8 @@ builds:
|
||||
- -trimpath
|
||||
ldflags:
|
||||
- -s -w
|
||||
- -X main.version={{.Version}}
|
||||
- -X main.commit={{.Commit}}
|
||||
goos:
|
||||
- linux
|
||||
- freebsd
|
||||
@@ -17,6 +19,7 @@ builds:
|
||||
- 386
|
||||
- arm
|
||||
- mips
|
||||
- mipsle
|
||||
- amd64
|
||||
- arm64
|
||||
goarm:
|
||||
|
||||
96
README.md
96
README.md
@@ -9,6 +9,10 @@ A highly configurable DNS forwarding proxy with support for:
|
||||
- Multiple upstreams with fallbacks
|
||||
- Multiple network policy driven DNS query steering
|
||||
- Policy driven domain based "split horizon" DNS with wildcard support
|
||||
- Integrations with common router vendors and firmware
|
||||
|
||||
## TLDR
|
||||
Proxy legacy DNS traffic to secure DNS upstreams in highly configurable ways.
|
||||
|
||||
All DNS protocols are supported, including:
|
||||
- `UDP 53`
|
||||
@@ -17,23 +21,41 @@ All DNS protocols are supported, including:
|
||||
- `DNS-over-HTTP/3` (DOH3)
|
||||
- `DNS-over-QUIC`
|
||||
|
||||
## Use Cases
|
||||
# Use Cases
|
||||
1. Use secure DNS protocols on networks and devices that don't natively support them (legacy routers, legacy OSes, TVs, smart toasters).
|
||||
2. Create source IP based DNS routing policies with variable secure DNS upstreams. Subnet 1 (admin) uses upstream resolver A, while Subnet 2 (employee) uses upstream resolver B.
|
||||
3. Create destination IP based DNS routing policies with variable secure DNS upstreams. Listener 1 uses upstream resolver C, while Listener 2 uses upstream resolver D.
|
||||
4. Create domain level "split horizon" DNS routing policies to send internal domains (*.company.int) to a local DNS server, while everything else goes to another upstream.
|
||||
5. Deploy on a router and create LAN client specific DNS routing policies from a web GUI (When using ControlD.com).
|
||||
|
||||
|
||||
## OS Support
|
||||
- Windows (386, amd64, arm)
|
||||
- Mac (amd64, arm64)
|
||||
- Linux (386, amd64, arm, mips)
|
||||
- FreeBSD
|
||||
- Common routers (See Router Mode below)
|
||||
|
||||
## Download
|
||||
Download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section.
|
||||
# Install
|
||||
There are several ways to download and install `ctrld.
|
||||
|
||||
## Quick Install
|
||||
The simplest way to download and install `ctrld` is to use the following installer command on any UNIX-like platform:
|
||||
|
||||
```shell
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl)"'
|
||||
```
|
||||
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative cmd:
|
||||
```shell
|
||||
powershell -Command "(Invoke-WebRequest -Uri 'https://api.controld.com/dl' -UseBasicParsing).Content | Set-Content 'ctrld_install.bat'" && ctrld_install.bat
|
||||
```
|
||||
|
||||
## Download Manually
|
||||
Alternatively, if you know what you're doing you can download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section for the appropriate platform.
|
||||
|
||||
## Build
|
||||
`ctrld` requires `go1.19+`:
|
||||
Lastly, you can build `ctrld` from source which requires `go1.19+`:
|
||||
|
||||
```shell
|
||||
$ go build ./cmd/ctrld
|
||||
@@ -45,6 +67,10 @@ or
|
||||
$ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
```
|
||||
|
||||
|
||||
# Usage
|
||||
The cli is self documenting, so free free to run `--help` on any sub-command to get specific usages.
|
||||
|
||||
## Arguments
|
||||
```
|
||||
__ .__ .___
|
||||
@@ -60,18 +86,29 @@ Usage:
|
||||
Available Commands:
|
||||
run Run the DNS proxy server
|
||||
service Manage ctrld service
|
||||
start Quick start service and configure DNS on default interface
|
||||
stop Quick stop service and remove DNS from default interface
|
||||
start Quick start service and configure DNS on interface
|
||||
stop Quick stop service and remove DNS from interface
|
||||
setup Auto-setup Control D on a router.
|
||||
|
||||
Supported platforms:
|
||||
|
||||
ₒ ddwrt
|
||||
ₒ merlin
|
||||
ₒ openwrt
|
||||
ₒ ubios
|
||||
ₒ auto - detect the platform you are running on
|
||||
|
||||
|
||||
Flags:
|
||||
-h, --help help for ctrld
|
||||
-s, --silent do not write any log output
|
||||
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug level logging
|
||||
--version version for ctrld
|
||||
|
||||
Use "ctrld [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
## Usage
|
||||
## Basic Run Mode
|
||||
To start the server with default configuration, simply run: `./ctrld run`. This will create a generic `ctrld.toml` file in the **working directory** and start the application in foreground.
|
||||
1. Start the server
|
||||
```
|
||||
@@ -85,12 +122,12 @@ To start the server with default configuration, simply run: `./ctrld run`. This
|
||||
147.185.34.1
|
||||
```
|
||||
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream.
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file and go nuts with it. To enforce a new config, restart the server.
|
||||
|
||||
### Service Mode
|
||||
To run the application in service mode, simply run: `./ctrld start` as system/root user. This will create a generic `ctrld.toml` file in the **user home** directory, start the system service, and configure the listener on the default interface. Service will start on OS boot.
|
||||
## Service Mode
|
||||
To run the application in service mode on any Windows, MacOS or Linux distibution, simply run: `./ctrld start` as system/root user. This will create a generic `ctrld.toml` file in the **user home** directory (on Windows) or `/etc/controld/` (everywhere else), start the system service, and configure the listener on the default network interface. Service will start on OS boot.
|
||||
|
||||
In order to stop the service, and restore your DNS to original state, simply run `./ctrld stop`.
|
||||
In order to stop the service, and restore your DNS to original state, simply run `./ctrld stop`. If you wish to uninstall the service permanently, run `./ctrld service uninstall`.
|
||||
|
||||
For granular control of the service, run the `service` command. Each sub-command has its own help section so you can see what arguments you can supply.
|
||||
|
||||
@@ -117,8 +154,23 @@ For granular control of the service, run the `service` command. Each sub-command
|
||||
Use "ctrld service [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
## Router Mode
|
||||
You can run `ctrld` on any supported router, which will function similarly to the Service Mode mentioned above. The list of supported routers and firmware includes:
|
||||
- Asus Merlin
|
||||
- DD-WRT
|
||||
- FreshTomato
|
||||
- GL.iNet
|
||||
- OpenWRT
|
||||
- pfSense
|
||||
- Synology
|
||||
- Ubiquiti (UniFi, EdgeOS)
|
||||
|
||||
In order to start `ctrld` as a DNS provider, simply run `./ctrld setup auto` command.
|
||||
|
||||
In this mode, and when Control D upstreams are used, the router will [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
|
||||
### Control D Auto Configuration
|
||||
Application can be started with a specific resolver config, instead of the default one. Simply supply your resolver ID with a `--cd` flag, when using the `run` (foreground) or `start` (service) modes.
|
||||
Application can be started with a specific resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `run` (foreground) or `start` (service) or `setup` (router) modes.
|
||||
|
||||
The following command will start the application in foreground mode, using the free "p2" resolver, which blocks Ads & Trackers.
|
||||
|
||||
@@ -126,22 +178,27 @@ The following command will start the application in foreground mode, using the f
|
||||
./ctrld run --cd p2
|
||||
```
|
||||
|
||||
Alternatively, you can use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is the part after the slash of your DNS-over-HTTPS resolver. ie. https://dns.controld.com/abcd1234
|
||||
Alternatively, you can use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Device.
|
||||
|
||||
```shell
|
||||
./ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above command, the following things will happen:
|
||||
You can do the same while starting in router mode:
|
||||
```shell
|
||||
./ctrld setup auto --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above commands (in service or router modes only), the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `service uninstall` sub-commands
|
||||
- Your default network interface will be updated to use the listener started by the service
|
||||
- All OS DNS queries will be sent to the listener
|
||||
|
||||
## Configuration
|
||||
# Configuration
|
||||
See [Configuration Docs](docs/config.md).
|
||||
|
||||
### Example
|
||||
## Example
|
||||
- Start `listener.0` on 127.0.0.1:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `upstream.0` via DoH protocol
|
||||
@@ -183,17 +240,16 @@ See [Configuration Docs](docs/config.md).
|
||||
|
||||
```
|
||||
|
||||
### Advanced
|
||||
## Advanced Configuration
|
||||
The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file.
|
||||
|
||||
You can also supply configuration via launch argeuments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
|
||||
## Contributing
|
||||
|
||||
See [Contribution Guideline](./docs/contributing.md)
|
||||
|
||||
## Roadmap
|
||||
The following functionality is on the roadmap and will be available in future releases.
|
||||
- Router self-installation
|
||||
- Client hostname/MAC passthrough
|
||||
- Prometheus metrics exporter
|
||||
- DNS intercept mode
|
||||
- Support for more routers (let us know which ones)
|
||||
|
||||
11
client_info.go
Normal file
11
client_info.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package ctrld
|
||||
|
||||
// ClientInfoCtxKey is the context key to store client info.
|
||||
type ClientInfoCtxKey struct{}
|
||||
|
||||
// ClientInfo represents ctrld's clients information.
|
||||
type ClientInfo struct {
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
}
|
||||
531
cmd/ctrld/cli.go
531
cmd/ctrld/cli.go
@@ -3,10 +3,10 @@ package main
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -15,10 +15,11 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cuonglm/osinfo"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/miekg/dns"
|
||||
@@ -29,16 +30,22 @@ import (
|
||||
"tailscale.com/net/interfaces"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/certs"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
const selfCheckFQDN = "verify.controld.com"
|
||||
var (
|
||||
version = "dev"
|
||||
commit = "none"
|
||||
)
|
||||
|
||||
var (
|
||||
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
defaultConfigWritten = false
|
||||
defaultConfigFile = "ctrld.toml"
|
||||
rootCertPool *x509.CertPool
|
||||
)
|
||||
|
||||
var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"}
|
||||
@@ -61,23 +68,44 @@ _/ ___\ __\_ __ \ | / __ |
|
||||
\/ dns forwarding proxy \/
|
||||
`
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "ctrld",
|
||||
Short: strings.TrimLeft(rootShortDesc, "\n"),
|
||||
Version: curVersion(),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
}
|
||||
|
||||
func curVersion() string {
|
||||
if version != "dev" && !strings.HasPrefix(version, "v") {
|
||||
version = "v" + version
|
||||
}
|
||||
if len(commit) > 7 {
|
||||
commit = commit[:7]
|
||||
}
|
||||
return fmt.Sprintf("%s-%s", version, commit)
|
||||
}
|
||||
|
||||
func initCLI() {
|
||||
// Enable opening via explorer.exe on Windows.
|
||||
// See: https://github.com/spf13/cobra/issues/844.
|
||||
cobra.MousetrapHelpText = ""
|
||||
cobra.EnableCommandSorting = false
|
||||
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "ctrld",
|
||||
Short: strings.TrimLeft(rootShortDesc, "\n"),
|
||||
Version: "1.1.2",
|
||||
}
|
||||
rootCmd.PersistentFlags().CountVarP(
|
||||
&verbose,
|
||||
"verbose",
|
||||
"v",
|
||||
`verbose log output, "-v" basic logging, "-vv" debug level logging`,
|
||||
)
|
||||
rootCmd.PersistentFlags().BoolVarP(
|
||||
&silent,
|
||||
"silent",
|
||||
"s",
|
||||
false,
|
||||
`do not write any log output`,
|
||||
)
|
||||
rootCmd.SetHelpCommand(&cobra.Command{Hidden: true})
|
||||
rootCmd.CompletionOptions.HiddenDefaultCmd = true
|
||||
|
||||
@@ -85,45 +113,70 @@ func initCLI() {
|
||||
Use: "run",
|
||||
Short: "Run the DNS proxy server",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if daemon && runtime.GOOS == "windows" {
|
||||
log.Fatal("Cannot run in daemon mode. Please install a Windows service.")
|
||||
mainLog.Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.")
|
||||
}
|
||||
|
||||
waitCh := make(chan struct{})
|
||||
stopCh := make(chan struct{})
|
||||
if !daemon {
|
||||
// We need to call s.Run() as soon as possible to response to the OS manager, so it
|
||||
// can see ctrld is running and don't mark ctrld as failed service.
|
||||
go func() {
|
||||
p := &prog{
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
}
|
||||
s, err := service.New(p, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed create new service")
|
||||
}
|
||||
s = newService(s)
|
||||
if err := s.Run(); err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to start service")
|
||||
}
|
||||
}()
|
||||
}
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
configs := []struct {
|
||||
name string
|
||||
written bool
|
||||
}{
|
||||
// For compatibility, we check for config.toml first, but only read it if exists.
|
||||
{"config", false},
|
||||
{"ctrld", writeDefaultConfig},
|
||||
}
|
||||
for _, config := range configs {
|
||||
ctrld.SetConfigName(v, config.name)
|
||||
v.SetConfigFile(configPath)
|
||||
if readConfigFile(config.written) {
|
||||
break
|
||||
}
|
||||
}
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
readBase64Config()
|
||||
readBase64Config(configBase64)
|
||||
processNoConfigFlags(noConfigStart)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
log.Fatalf("failed to unmarshal config: %v", err)
|
||||
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
fmt.Println("starting ctrld...")
|
||||
|
||||
mainLog.Info().Msgf("starting ctrld %s", curVersion())
|
||||
oi := osinfo.New()
|
||||
mainLog.Info().Msgf("os: %s", oi.String())
|
||||
|
||||
// Wait for network up.
|
||||
if !ctrldnet.Up() {
|
||||
log.Fatal("network is not up yet")
|
||||
mainLog.Fatal().Msg("network is not up yet")
|
||||
}
|
||||
processLogAndCacheFlags()
|
||||
// Log config do not have thing to validate, so it's safe to init log here,
|
||||
// so it's able to log information in processCDFlags.
|
||||
initLogging()
|
||||
|
||||
if setupRouter {
|
||||
s, errCh := runDNSServerForNTPD(router.ListenAddress())
|
||||
if err := router.PreRun(); err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to perform router pre-start check")
|
||||
}
|
||||
if err := s.Shutdown(); err != nil && errCh != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to shutdown dns server for ntpd")
|
||||
}
|
||||
}
|
||||
|
||||
processCDFlags()
|
||||
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
|
||||
log.Fatalf("invalid config: %v", err)
|
||||
mainLog.Fatal().Msgf("invalid config: %v", err)
|
||||
}
|
||||
initCache()
|
||||
|
||||
@@ -149,22 +202,26 @@ func initCLI() {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed create new service")
|
||||
}
|
||||
serviceLogger, err := s.Logger(nil)
|
||||
if err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to get service logger")
|
||||
return
|
||||
if setupRouter {
|
||||
switch platform := router.Name(); {
|
||||
case platform == router.DDWrt:
|
||||
rootCertPool = certs.CACertPool()
|
||||
fallthrough
|
||||
case platform != "":
|
||||
mainLog.Debug().Msg("Router setup")
|
||||
err := router.Configure(&cfg)
|
||||
if errors.Is(err, router.ErrNotSupported) {
|
||||
unsupportedPlatformHelp(cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to configure router")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.Run(); err != nil {
|
||||
if sErr := serviceLogger.Error(err); sErr != nil {
|
||||
mainLog.Error().Err(sErr).Msg("failed to write service log")
|
||||
}
|
||||
mainLog.Error().Err(err).Msg("failed to start service")
|
||||
}
|
||||
close(waitCh)
|
||||
<-stopCh
|
||||
},
|
||||
}
|
||||
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
|
||||
@@ -177,18 +234,25 @@ func initCLI() {
|
||||
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = runCmd.Flags().MarkHidden("dev")
|
||||
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
|
||||
_ = runCmd.Flags().MarkHidden("homedir")
|
||||
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
runCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
|
||||
_ = runCmd.Flags().MarkHidden("router")
|
||||
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
startCmd := &cobra.Command{
|
||||
PreRun: checkHasElevatedPrivilege,
|
||||
Use: "start",
|
||||
Short: "Install and start the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Install and start the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
sc := &service.Config{}
|
||||
*sc = *svcConfig
|
||||
@@ -198,6 +262,9 @@ func initCLI() {
|
||||
}
|
||||
setDependencies(sc)
|
||||
sc.Arguments = append([]string{"run"}, osArgs...)
|
||||
if err := router.ConfigureService(sc); err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to configure service on router")
|
||||
}
|
||||
|
||||
// No config path, generating config in HOME directory.
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
@@ -205,18 +272,18 @@ func initCLI() {
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
if dir, err := os.UserHomeDir(); err == nil {
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
setWorkingDirectory(sc, dir)
|
||||
if configPath == "" && writeDefaultConfig {
|
||||
defaultConfigFile = filepath.Join(dir, defaultConfigFile)
|
||||
v.SetConfigFile(defaultConfigFile)
|
||||
}
|
||||
sc.Arguments = append(sc.Arguments, "--homedir="+dir)
|
||||
}
|
||||
|
||||
readConfigFile(writeDefaultConfig && cdUID == "")
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
log.Fatalf("failed to unmarshal config: %v", err)
|
||||
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
logPath := cfg.Service.LogPath
|
||||
@@ -225,20 +292,25 @@ func initCLI() {
|
||||
cfg.Service.LogPath = logPath
|
||||
|
||||
processCDFlags()
|
||||
// On Windows, the service will be run as SYSTEM, so if ctrld start as Admin,
|
||||
// the user home dir is different, so pass specific arguments that relevant here.
|
||||
if runtime.GOOS == "windows" {
|
||||
if configPath == "" {
|
||||
sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile)
|
||||
}
|
||||
|
||||
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("invalid config: %v", err)
|
||||
}
|
||||
|
||||
// Explicitly passing config, so on system where home directory could not be obtained,
|
||||
// or sub-process env is different with the parent, we still behave correctly and use
|
||||
// the expected config file.
|
||||
if configPath == "" {
|
||||
sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile)
|
||||
}
|
||||
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, sc)
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
tasks := []task{
|
||||
{s.Stop, false},
|
||||
{s.Uninstall, false},
|
||||
@@ -246,16 +318,21 @@ func initCLI() {
|
||||
{s.Start, true},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
status, err := s.Status()
|
||||
if err := router.PostInstall(svcConfig); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("post installation failed, please check system/service log for details error")
|
||||
return
|
||||
}
|
||||
status, err := serviceStatus(s)
|
||||
if err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not get service status")
|
||||
return
|
||||
}
|
||||
|
||||
status = selfCheckStatus(status)
|
||||
domain := cfg.Upstream["0"].VerifyDomain()
|
||||
status = selfCheckStatus(status, domain)
|
||||
switch status {
|
||||
case service.StatusRunning:
|
||||
mainLog.Info().Msg("Service started")
|
||||
mainLog.Notice().Msg("Service started")
|
||||
default:
|
||||
mainLog.Error().Msg("Service did not start, please check system/service log for details error")
|
||||
if runtime.GOOS == "linux" {
|
||||
@@ -277,43 +354,55 @@ func initCLI() {
|
||||
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = startCmd.Flags().MarkHidden("dev")
|
||||
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
|
||||
_ = startCmd.Flags().MarkHidden("router")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
PreRun: checkHasElevatedPrivilege,
|
||||
Use: "stop",
|
||||
Short: "Stop the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "stop",
|
||||
Short: "Stop the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, svcConfig)
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
initLogging()
|
||||
if doTasks([]task{{s.Stop, true}}) {
|
||||
prog.resetDNS()
|
||||
mainLog.Info().Msg("Service stopped")
|
||||
mainLog.Notice().Msg("Service stopped")
|
||||
}
|
||||
},
|
||||
}
|
||||
stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
|
||||
restartCmd := &cobra.Command{
|
||||
PreRun: checkHasElevatedPrivilege,
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
initLogging()
|
||||
if doTasks([]task{{s.Restart, true}}) {
|
||||
stdoutMsg("Service restarted")
|
||||
mainLog.Notice().Msg("Service restarted")
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -322,41 +411,58 @@ func initCLI() {
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
status, err := s.Status()
|
||||
s = newService(s)
|
||||
status, err := serviceStatus(s)
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
mainLog.Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
switch status {
|
||||
case service.StatusUnknown:
|
||||
stdoutMsg("Unknown status")
|
||||
mainLog.Notice().Msg("Unknown status")
|
||||
os.Exit(2)
|
||||
case service.StatusRunning:
|
||||
stdoutMsg("Service is running")
|
||||
mainLog.Notice().Msg("Service is running")
|
||||
os.Exit(0)
|
||||
case service.StatusStopped:
|
||||
stdoutMsg("Service is stopped")
|
||||
mainLog.Notice().Msg("Service is stopped")
|
||||
os.Exit(1)
|
||||
}
|
||||
},
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
// On darwin, running status command without privileges may return wrong information.
|
||||
statusCmd.PreRun = func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
}
|
||||
}
|
||||
|
||||
uninstallCmd := &cobra.Command{
|
||||
PreRun: checkHasElevatedPrivilege,
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Long: `Stop and uninstall the ctrld service.
|
||||
|
||||
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, svcConfig)
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
tasks := []task{
|
||||
@@ -365,18 +471,28 @@ func initCLI() {
|
||||
}
|
||||
initLogging()
|
||||
if doTasks(tasks) {
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
prog.resetDNS()
|
||||
mainLog.Info().Msg("Service uninstalled")
|
||||
mainLog.Debug().Msg("Router cleanup")
|
||||
if err := router.Cleanup(svcConfig); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
mainLog.Notice().Msg("Service uninstalled")
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`)
|
||||
|
||||
listIfacesCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List network interfaces of the host",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
err := interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
fmt.Printf("Index : %d\n", i.Index)
|
||||
@@ -399,7 +515,7 @@ func initCLI() {
|
||||
println()
|
||||
})
|
||||
if err != nil {
|
||||
stderrMsg(err.Error())
|
||||
mainLog.Error().Msg(err.Error())
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -434,9 +550,12 @@ func initCLI() {
|
||||
serviceCmd.AddCommand(interfacesCmd)
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
startCmdAlias := &cobra.Command{
|
||||
PreRun: checkHasElevatedPrivilege,
|
||||
Use: "start",
|
||||
Short: "Quick start service and configure DNS on interface",
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Quick start service and configure DNS on interface",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
@@ -449,9 +568,12 @@ func initCLI() {
|
||||
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
||||
rootCmd.AddCommand(startCmdAlias)
|
||||
stopCmdAlias := &cobra.Command{
|
||||
PreRun: checkHasElevatedPrivilege,
|
||||
Use: "stop",
|
||||
Short: "Quick stop service and remove DNS from interface",
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "stop",
|
||||
Short: "Quick stop service and remove DNS from interface",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
@@ -463,11 +585,6 @@ func initCLI() {
|
||||
stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
|
||||
rootCmd.AddCommand(stopCmdAlias)
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
stderrMsg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func writeConfigFile() error {
|
||||
@@ -500,15 +617,8 @@ func readConfigFile(writeDefaultConfig bool) bool {
|
||||
// If err == nil, there's a config supplied via `--config`, no default config written.
|
||||
err := v.ReadInConfig()
|
||||
if err == nil {
|
||||
fmt.Println("loading config file from:", v.ConfigFileUsed())
|
||||
mainLog.Info().Msg("loading config file from: " + v.ConfigFileUsed())
|
||||
defaultConfigFile = v.ConfigFileUsed()
|
||||
v.OnConfigChange(func(in fsnotify.Event) {
|
||||
if err := v.UnmarshalKey("listener", &cfg.Listener); err != nil {
|
||||
log.Printf("failed to unmarshal listener config: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
v.WatchConfig()
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -518,29 +628,36 @@ func readConfigFile(writeDefaultConfig bool) bool {
|
||||
|
||||
// If error is viper.ConfigFileNotFoundError, write default config.
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal default config: %v", err)
|
||||
}
|
||||
if err := writeConfigFile(); err != nil {
|
||||
log.Fatalf("failed to write default config file: %v", err)
|
||||
mainLog.Fatal().Msgf("failed to write default config file: %v", err)
|
||||
} else {
|
||||
fmt.Println("writing default config file to: " + defaultConfigFile)
|
||||
fp, err := filepath.Abs(defaultConfigFile)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("failed to get default config file path: %v", err)
|
||||
}
|
||||
mainLog.Info().Msg("writing default config file to: " + fp)
|
||||
}
|
||||
defaultConfigWritten = true
|
||||
return false
|
||||
}
|
||||
// Otherwise, report fatal error and exit.
|
||||
log.Fatalf("failed to decode config file: %v", err)
|
||||
mainLog.Fatal().Msgf("failed to decode config file: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
func readBase64Config() {
|
||||
func readBase64Config(configBase64 string) {
|
||||
if configBase64 == "" {
|
||||
return
|
||||
}
|
||||
configStr, err := base64.StdEncoding.DecodeString(configBase64)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid base64 config: %v", err)
|
||||
mainLog.Fatal().Msgf("invalid base64 config: %v", err)
|
||||
}
|
||||
if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil {
|
||||
log.Fatalf("failed to read base64 config: %v", err)
|
||||
mainLog.Fatal().Msgf("failed to read base64 config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -549,22 +666,30 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
return
|
||||
}
|
||||
if listenAddress == "" || primaryUpstream == "" {
|
||||
log.Fatal(`"listen" and "primary_upstream" flags must be set in no config mode`)
|
||||
mainLog.Fatal().Msg(`"listen" and "primary_upstream" flags must be set in no config mode`)
|
||||
}
|
||||
processListenFlag()
|
||||
|
||||
endpointAndTyp := func(endpoint string) (string, string) {
|
||||
typ := ctrld.ResolverTypeFromEndpoint(endpoint)
|
||||
return strings.TrimPrefix(endpoint, "quic://"), typ
|
||||
}
|
||||
pEndpoint, pType := endpointAndTyp(primaryUpstream)
|
||||
upstream := map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
Name: primaryUpstream,
|
||||
Endpoint: primaryUpstream,
|
||||
Type: ctrld.ResolverTypeDOH,
|
||||
Name: pEndpoint,
|
||||
Endpoint: pEndpoint,
|
||||
Type: pType,
|
||||
Timeout: 5000,
|
||||
},
|
||||
}
|
||||
if secondaryUpstream != "" {
|
||||
sEndpoint, sType := endpointAndTyp(secondaryUpstream)
|
||||
upstream["1"] = &ctrld.UpstreamConfig{
|
||||
Name: secondaryUpstream,
|
||||
Endpoint: secondaryUpstream,
|
||||
Type: ctrld.ResolverTypeLegacy,
|
||||
Name: sEndpoint,
|
||||
Endpoint: sEndpoint,
|
||||
Type: sType,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(domains))
|
||||
for _, domain := range domains {
|
||||
@@ -585,7 +710,7 @@ func processCDFlags() {
|
||||
}
|
||||
logger := mainLog.With().Str("mode", "cd").Logger()
|
||||
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
|
||||
resolverConfig, err := controld.FetchResolverConfig(cdUID)
|
||||
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
@@ -617,34 +742,57 @@ func processCDFlags() {
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info().Msg("generating ctrld config from Controld-D configuration")
|
||||
cfg = ctrld.Config{}
|
||||
cfg.Network = make(map[string]*ctrld.NetworkConfig)
|
||||
cfg.Network["0"] = &ctrld.NetworkConfig{
|
||||
Name: "Network 0",
|
||||
Cidrs: []string{"0.0.0.0/0"},
|
||||
}
|
||||
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
|
||||
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
|
||||
Endpoint: resolverConfig.DOH,
|
||||
Type: ctrld.ResolverTypeDOH,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
|
||||
for _, domain := range resolverConfig.Exclude {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
|
||||
cfg.Listener["0"] = &ctrld.ListenerConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 53,
|
||||
Policy: &ctrld.ListenerPolicyConfig{
|
||||
Name: "My Policy",
|
||||
Rules: rules,
|
||||
},
|
||||
logger.Info().Msg("generating ctrld config from Control-D configuration")
|
||||
if resolverConfig.Ctrld.CustomConfig != "" {
|
||||
logger.Info().Msg("using defined custom config of Control-D resolver")
|
||||
readBase64Config(resolverConfig.Ctrld.CustomConfig)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
for _, listener := range cfg.Listener {
|
||||
if listener.IP == "" {
|
||||
listener.IP = randomLocalIP()
|
||||
}
|
||||
if listener.Port == 0 {
|
||||
listener.Port = 53
|
||||
}
|
||||
}
|
||||
// On router, we want to keep the listener address point to dnsmasq listener, aka 127.0.0.1:53.
|
||||
if router.Name() != "" {
|
||||
if lc := cfg.Listener["0"]; lc != nil {
|
||||
lc.IP = "127.0.0.1"
|
||||
lc.Port = 53
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cfg = ctrld.Config{}
|
||||
cfg.Network = make(map[string]*ctrld.NetworkConfig)
|
||||
cfg.Network["0"] = &ctrld.NetworkConfig{
|
||||
Name: "Network 0",
|
||||
Cidrs: []string{"0.0.0.0/0"},
|
||||
}
|
||||
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
|
||||
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
|
||||
Endpoint: resolverConfig.DOH,
|
||||
Type: ctrld.ResolverTypeDOH,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
|
||||
for _, domain := range resolverConfig.Exclude {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
|
||||
cfg.Listener["0"] = &ctrld.ListenerConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 53,
|
||||
Policy: &ctrld.ListenerPolicyConfig{
|
||||
Name: "My Policy",
|
||||
Rules: rules,
|
||||
},
|
||||
}
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
processLogAndCacheFlags()
|
||||
if err := writeConfigFile(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to write config file")
|
||||
} else {
|
||||
@@ -658,11 +806,11 @@ func processListenFlag() {
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(listenAddress)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid listener address: %v", err)
|
||||
mainLog.Fatal().Msgf("invalid listener address: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid port number: %v", err)
|
||||
mainLog.Fatal().Msgf("invalid port number: %v", err)
|
||||
}
|
||||
lc := &ctrld.ListenerConfig{
|
||||
IP: host,
|
||||
@@ -708,31 +856,104 @@ func netInterface(ifaceName string) (*net.Interface, error) {
|
||||
func defaultIfaceName() string {
|
||||
dri, err := interfaces.DefaultRouteInterface()
|
||||
if err != nil {
|
||||
// On WSL 1, the route table does not have any default route. But the fact that
|
||||
// it only uses /etc/resolv.conf for setup DNS, so we can use "lo" here.
|
||||
if oi := osinfo.New(); strings.Contains(oi.String(), "Microsoft") {
|
||||
return "lo"
|
||||
}
|
||||
mainLog.Fatal().Err(err).Msg("failed to get default route interface")
|
||||
}
|
||||
return dri
|
||||
}
|
||||
|
||||
func selfCheckStatus(status service.Status) service.Status {
|
||||
func selfCheckStatus(status service.Status, domain string) service.Status {
|
||||
if domain == "" {
|
||||
// Nothing to do, return the status as-is.
|
||||
return status
|
||||
}
|
||||
c := new(dns.Client)
|
||||
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
|
||||
bo.LogLongerThan = 500 * time.Millisecond
|
||||
ctx := context.Background()
|
||||
err := errors.New("query failed")
|
||||
maxAttempts := 20
|
||||
mainLog.Debug().Msg("Performing self-check")
|
||||
var (
|
||||
lcChanged map[string]*ctrld.ListenerConfig
|
||||
mu sync.Mutex
|
||||
)
|
||||
v.OnConfigChange(func(in fsnotify.Event) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err := v.UnmarshalKey("listener", &lcChanged); err != nil {
|
||||
mainLog.Error().Msgf("failed to unmarshal listener config: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
v.WatchConfig()
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
lc := cfg.Listener["0"]
|
||||
mu.Lock()
|
||||
if lcChanged != nil {
|
||||
lc = lcChanged["0"]
|
||||
}
|
||||
mu.Unlock()
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(selfCheckFQDN+".", dns.TypeA)
|
||||
m.SetQuestion(domain+".", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
r, _, _ := c.ExchangeContext(ctx, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)))
|
||||
r, _, err := c.ExchangeContext(ctx, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)))
|
||||
if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 {
|
||||
mainLog.Debug().Msgf("self-check against %q succeeded", selfCheckFQDN)
|
||||
mainLog.Debug().Msgf("self-check against %q succeeded", domain)
|
||||
return status
|
||||
}
|
||||
bo.BackOff(ctx, err)
|
||||
bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", err))
|
||||
}
|
||||
mainLog.Debug().Msgf("self-check against %q failed", selfCheckFQDN)
|
||||
mainLog.Debug().Msgf("self-check against %q failed", domain)
|
||||
return service.StatusUnknown
|
||||
}
|
||||
|
||||
func unsupportedPlatformHelp(cmd *cobra.Command) {
|
||||
mainLog.Error().Msg("Unsupported or incorrectly chosen router platform. Please open an issue and provide all relevant information: https://github.com/Control-D-Inc/ctrld/issues/new")
|
||||
}
|
||||
|
||||
func userHomeDir() (string, error) {
|
||||
switch router.Name() {
|
||||
case router.DDWrt, router.Merlin, router.Tomato:
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Dir(exe), nil
|
||||
}
|
||||
// viper will expand for us.
|
||||
if runtime.GOOS == "windows" {
|
||||
return os.UserHomeDir()
|
||||
}
|
||||
dir := "/etc/controld"
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func tryReadingConfig(writeDefaultConfig bool) {
|
||||
configs := []struct {
|
||||
name string
|
||||
written bool
|
||||
}{
|
||||
// For compatibility, we check for config.toml first, but only read it if exists.
|
||||
{"config", false},
|
||||
{"ctrld", writeDefaultConfig},
|
||||
}
|
||||
|
||||
dir, err := userHomeDir()
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("failed to get config dir: %v", err)
|
||||
}
|
||||
for _, config := range configs {
|
||||
ctrld.SetConfigNameWithPath(v, config.name, dir)
|
||||
v.SetConfigFile(configPath)
|
||||
if readConfigFile(config.written) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
100
cmd/ctrld/cli_router.go
Normal file
100
cmd/ctrld/cli_router.go
Normal file
@@ -0,0 +1,100 @@
|
||||
//go:build linux || freebsd
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func initRouterCLI() {
|
||||
validArgs := append(router.SupportedPlatforms(), "auto")
|
||||
var b strings.Builder
|
||||
b.WriteString("Auto-setup Control D on a router.\n\nSupported platforms:\n\n")
|
||||
for _, arg := range validArgs {
|
||||
b.WriteString(" ₒ ")
|
||||
b.WriteString(arg)
|
||||
if arg == "auto" {
|
||||
b.WriteString(" - detect the platform you are running on")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
routerCmd := &cobra.Command{
|
||||
Use: "setup",
|
||||
Short: b.String(),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if len(args) == 0 {
|
||||
_ = cmd.Help()
|
||||
return
|
||||
}
|
||||
if len(args) != 1 {
|
||||
_ = cmd.Help()
|
||||
return
|
||||
}
|
||||
platform := args[0]
|
||||
if platform == "auto" {
|
||||
platform = router.Name()
|
||||
}
|
||||
if !router.IsSupported(platform) {
|
||||
unsupportedPlatformHelp(cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("could not find executable path: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cmdArgs := []string{"start"}
|
||||
cmdArgs = append(cmdArgs, osArgs(platform)...)
|
||||
cmdArgs = append(cmdArgs, "--router")
|
||||
command := exec.Command(exe, cmdArgs...)
|
||||
command.Stdout = os.Stdout
|
||||
command.Stderr = os.Stderr
|
||||
command.Stdin = os.Stdin
|
||||
if err := command.Run(); err != nil {
|
||||
mainLog.Fatal().Msg(err.Error())
|
||||
}
|
||||
},
|
||||
}
|
||||
// Keep these flags in sync with startCmd, except for "--router".
|
||||
routerCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
routerCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
routerCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
routerCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
routerCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
routerCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
routerCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
routerCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
routerCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
routerCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = routerCmd.Flags().MarkHidden("dev")
|
||||
routerCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
|
||||
tmpl := routerCmd.UsageTemplate()
|
||||
tmpl = strings.Replace(tmpl, "{{.UseLine}}", "{{.UseLine}} [platform]", 1)
|
||||
routerCmd.SetUsageTemplate(tmpl)
|
||||
rootCmd.AddCommand(routerCmd)
|
||||
}
|
||||
|
||||
func osArgs(platform string) []string {
|
||||
args := os.Args[2:]
|
||||
n := 0
|
||||
for _, x := range args {
|
||||
if x != platform && x != "auto" {
|
||||
args[n] = x
|
||||
n++
|
||||
}
|
||||
}
|
||||
return args[:n]
|
||||
}
|
||||
5
cmd/ctrld/cli_router_others.go
Normal file
5
cmd/ctrld/cli_router_others.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !linux && !freebsd
|
||||
|
||||
package main
|
||||
|
||||
func initRouterCLI() {}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -17,9 +18,22 @@ import (
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
const staleTTL = 60 * time.Second
|
||||
const (
|
||||
staleTTL = 60 * time.Second
|
||||
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
|
||||
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
|
||||
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
|
||||
EDNS0_OPTION_MAC = 0xFDE9
|
||||
)
|
||||
|
||||
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "OS resolver",
|
||||
Type: ctrld.ResolverTypeOS,
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
func (p *prog) serveDNS(listenerNum string) error {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
@@ -36,11 +50,12 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
reqId := requestID()
|
||||
fmtSrcToDest := fmtRemoteToLocal(listenerNum, w.RemoteAddr().String(), w.LocalAddr().String())
|
||||
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), router.GetClientInfoByMac(macFromMsg(m)))
|
||||
fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String())
|
||||
t := time.Now()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
ctrld.Log(ctx, mainLog.Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
|
||||
upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, w.RemoteAddr(), domain)
|
||||
upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, domain)
|
||||
var answer *dns.Msg
|
||||
if !matched && listenerConfig.Restricted {
|
||||
answer = new(dns.Msg)
|
||||
@@ -56,40 +71,52 @@ func (p *prog) serveDNS(listenerNum string) error {
|
||||
}
|
||||
})
|
||||
|
||||
g := new(errgroup.Group)
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
for _, proto := range []string{"udp", "tcp"} {
|
||||
proto := proto
|
||||
// On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can
|
||||
// listen on ::1, then spawn a listener for receiving DNS requests.
|
||||
if runtime.GOOS == "windows" && ctrldnet.SupportsIPv6ListenLocal() {
|
||||
if needLocalIPv6Listener() {
|
||||
g.Go(func() error {
|
||||
s := &dns.Server{
|
||||
Addr: net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)),
|
||||
Net: proto,
|
||||
Handler: handler,
|
||||
}
|
||||
if err := s.ListenAndServe(); err != nil {
|
||||
mainLog.Error().Err(err).Msg("could not serving on ::1")
|
||||
s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), proto, handler)
|
||||
defer s.Shutdown()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case err := <-errCh:
|
||||
// Local ipv6 listener should not terminate ctrld.
|
||||
// It's a workaround for a quirk on Windows.
|
||||
mainLog.Warn().Err(err).Msg("local ipv6 listener failed")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
g.Go(func() error {
|
||||
s := &dns.Server{
|
||||
Addr: net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port)),
|
||||
Net: proto,
|
||||
Handler: handler,
|
||||
s, errCh := runDNSServer(dnsListenAddress(listenerNum, listenerConfig), proto, handler)
|
||||
defer s.Shutdown()
|
||||
if listenerConfig.Port == 0 {
|
||||
switch s.Net {
|
||||
case "udp":
|
||||
mainLog.Info().Msgf("Random port chosen for udp listener.%s: %s", listenerNum, s.PacketConn.LocalAddr())
|
||||
case "tcp":
|
||||
mainLog.Info().Msgf("Random port chosen for tcp listener.%s: %s", listenerNum, s.Listener.Addr())
|
||||
}
|
||||
}
|
||||
if err := s.ListenAndServe(); err != nil {
|
||||
mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// upstreamFor returns the list of upstreams for resolving the given domain,
|
||||
// matching by policies defined in the listener config. The second return value
|
||||
// reports whether the domain matches the policy.
|
||||
//
|
||||
// Though domain policy has higher priority than network policy, it is still
|
||||
// processed later, because policy logging want to know whether a network rule
|
||||
// is disregarded in favor of the domain level rule.
|
||||
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
|
||||
upstreams := []string{"upstream." + defaultUpstreamNum}
|
||||
matchedPolicy := "no policy"
|
||||
@@ -113,11 +140,43 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c
|
||||
upstreams = append([]string(nil), policyUpstreams...)
|
||||
}
|
||||
|
||||
var networkTargets []string
|
||||
var sourceIP net.IP
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
sourceIP = addr.IP
|
||||
case *net.TCPAddr:
|
||||
sourceIP = addr.IP
|
||||
}
|
||||
|
||||
networkRules:
|
||||
for _, rule := range lc.Policy.Networks {
|
||||
for source, targets := range rule {
|
||||
networkNum := strings.TrimPrefix(source, "network.")
|
||||
nc := p.cfg.Network[networkNum]
|
||||
if nc == nil {
|
||||
continue
|
||||
}
|
||||
for _, ipNet := range nc.IPNets {
|
||||
if ipNet.Contains(sourceIP) {
|
||||
matchedPolicy = lc.Policy.Name
|
||||
matchedNetwork = source
|
||||
networkTargets = targets
|
||||
matched = true
|
||||
break networkRules
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
// There's only one entry per rule, config validation ensures this.
|
||||
for source, targets := range rule {
|
||||
if source == domain || wildcardMatches(source, domain) {
|
||||
matchedPolicy = lc.Policy.Name
|
||||
if len(networkTargets) > 0 {
|
||||
matchedNetwork += " (unenforced)"
|
||||
}
|
||||
matchedRule = source
|
||||
do(targets)
|
||||
matched = true
|
||||
@@ -126,31 +185,8 @@ func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *c
|
||||
}
|
||||
}
|
||||
|
||||
var sourceIP net.IP
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
sourceIP = addr.IP
|
||||
case *net.TCPAddr:
|
||||
sourceIP = addr.IP
|
||||
}
|
||||
for _, rule := range lc.Policy.Networks {
|
||||
for source, targets := range rule {
|
||||
networkNum := strings.TrimPrefix(source, "network.")
|
||||
nc := p.cfg.Network[networkNum]
|
||||
if nc == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ipNet := range nc.IPNets {
|
||||
if ipNet.Contains(sourceIP) {
|
||||
matchedPolicy = lc.Policy.Name
|
||||
matchedNetwork = source
|
||||
do(targets)
|
||||
matched = true
|
||||
return upstreams, matched
|
||||
}
|
||||
}
|
||||
}
|
||||
if matched {
|
||||
do(networkTargets)
|
||||
}
|
||||
|
||||
return upstreams, matched
|
||||
@@ -199,8 +235,16 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
return dnsResolver.Resolve(resolveCtx, msg)
|
||||
}
|
||||
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
if upstreamConfig.UpstreamSendClientInfo() {
|
||||
ci := router.GetClientInfoByMac(macFromMsg(msg))
|
||||
if ci != nil {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "including client info with the request")
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci)
|
||||
}
|
||||
}
|
||||
answer, err := resolve1(n, upstreamConfig, msg)
|
||||
if err != nil {
|
||||
// Only do re-bootstrapping if bootstrap ip is not explicitly set by user.
|
||||
if err != nil && upstreamConfig.BootstrapIP == "" {
|
||||
ctrld.Log(ctx, mainLog.Debug().Err(err), "could not resolve query on first attempt, retrying...")
|
||||
// If any error occurred, re-bootstrap transport/ip, retry the request.
|
||||
upstreamConfig.ReBootstrap()
|
||||
@@ -214,6 +258,9 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
return answer
|
||||
}
|
||||
for n, upstreamConfig := range upstreamConfigs {
|
||||
if upstreamConfig == nil {
|
||||
continue
|
||||
}
|
||||
answer := resolve(n, upstreamConfig, msg)
|
||||
if answer == nil {
|
||||
if serveStaleCache && staleAnswer != nil {
|
||||
@@ -228,6 +275,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
ctrld.Log(ctx, mainLog.Debug(), "failover rcode matched, process to next upstream")
|
||||
continue
|
||||
}
|
||||
|
||||
// set compression, as it is not set by default when unpacking
|
||||
answer.Compress = true
|
||||
|
||||
if p.cache != nil {
|
||||
ttl := ttlFromMsg(answer)
|
||||
now := time.Now()
|
||||
@@ -341,8 +392,128 @@ func ttlFromMsg(msg *dns.Msg) uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "OS resolver",
|
||||
Type: ctrld.ResolverTypeOS,
|
||||
Timeout: 2000,
|
||||
func needLocalIPv6Listener() bool {
|
||||
// On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can
|
||||
// listen on ::1, then spawn a listener for receiving DNS requests.
|
||||
return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows"
|
||||
}
|
||||
|
||||
func dnsListenAddress(lcNum string, lc *ctrld.ListenerConfig) string {
|
||||
if addr := router.ListenAddress(); setupRouter && addr != "" && lcNum == "0" {
|
||||
return addr
|
||||
}
|
||||
return net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
|
||||
}
|
||||
|
||||
func macFromMsg(msg *dns.Msg) string {
|
||||
if opt := msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
switch e := s.(type) {
|
||||
case *dns.EDNS0_LOCAL:
|
||||
if e.Code == EDNS0_OPTION_MAC {
|
||||
return net.HardwareAddr(e.Data).String()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
|
||||
if ci != nil && ci.IP != "" {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
udpAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(ci.IP),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
return udpAddr
|
||||
case *net.TCPAddr:
|
||||
udpAddr := &net.TCPAddr{
|
||||
IP: net.ParseIP(ci.IP),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
return udpAddr
|
||||
}
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// runDNSServer starts a DNS server for given address and network,
|
||||
// with the given handler. It ensures the server has started listening.
|
||||
// Any error will be reported to the caller via returned channel.
|
||||
//
|
||||
// It's the caller responsibility to call Shutdown to close the server.
|
||||
func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-chan error) {
|
||||
s := &dns.Server{
|
||||
Addr: addr,
|
||||
Net: network,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
s.NotifyStartedFunc = waitLock.Unlock
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
if err := s.ListenAndServe(); err != nil {
|
||||
waitLock.Unlock()
|
||||
mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
waitLock.Lock()
|
||||
return s, errCh
|
||||
}
|
||||
|
||||
// runDNSServerForNTPD starts a DNS server listening on router.ListenAddress(). It must only be called when ctrld
|
||||
// running on router, before router.PreRun() to serve DNS request for NTP synchronization. The caller must call
|
||||
// s.Shutdown() explicitly when NTP is synced successfully.
|
||||
func runDNSServerForNTPD(addr string) (*dns.Server, <-chan error) {
|
||||
if addr == "" {
|
||||
return &dns.Server{}, nil
|
||||
}
|
||||
dnsResolver := ctrld.NewBootstrapResolver()
|
||||
s := &dns.Server{
|
||||
Addr: addr,
|
||||
Net: "udp",
|
||||
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
mainLog.Debug().Msg("Serving query for ntpd")
|
||||
resolveCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
if osUpstreamConfig.Timeout > 0 {
|
||||
timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(osUpstreamConfig.Timeout))
|
||||
defer cancel()
|
||||
resolveCtx = timeoutCtx
|
||||
}
|
||||
answer, err := dnsResolver.Resolve(resolveCtx, m)
|
||||
if err != nil {
|
||||
mainLog.Error().Err(err).Msgf("could not resolve: %v", m)
|
||||
return
|
||||
}
|
||||
if err := w.WriteMsg(answer); err != nil {
|
||||
mainLog.Error().Err(err).Msg("runDNSServerForNTPD: failed to send DNS response")
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
s.NotifyStartedFunc = waitLock.Unlock
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
if err := s.ListenAndServe(); err != nil {
|
||||
waitLock.Unlock()
|
||||
mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
waitLock.Lock()
|
||||
return s, errCh
|
||||
}
|
||||
|
||||
@@ -86,17 +86,17 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
domain string
|
||||
upstreams []string
|
||||
matched bool
|
||||
testLogMsg string
|
||||
}{
|
||||
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true},
|
||||
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true},
|
||||
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false},
|
||||
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, network := range []string{"udp", "tcp"} {
|
||||
var (
|
||||
addr net.Addr
|
||||
@@ -114,6 +114,9 @@ func Test_prog_upstreamFor(t *testing.T) {
|
||||
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
|
||||
assert.Equal(t, tc.matched, matched)
|
||||
assert.Equal(t, tc.upstreams, upstreams)
|
||||
if tc.testLogMsg != "" {
|
||||
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -152,3 +155,64 @@ func TestCache(t *testing.T) {
|
||||
assert.Equal(t, answer1.Rcode, got1.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.Rcode)
|
||||
}
|
||||
|
||||
func Test_macFromMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mac string
|
||||
wantMac bool
|
||||
}{
|
||||
{"has mac", "4c:20:b8:ab:87:1b", true},
|
||||
{"no mac", "4c:20:b8:ab:87:1b", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hw, err := net.ParseMAC(tc.mac)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
if tc.wantMac {
|
||||
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
||||
o.Option = append(o.Option, ec1)
|
||||
}
|
||||
m.Extra = append(m.Extra, o)
|
||||
got := macFromMsg(m)
|
||||
if tc.wantMac && got != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, got)
|
||||
}
|
||||
if !tc.wantMac && got != "" {
|
||||
t.Errorf("unexpected mac: %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_remoteAddrFromMsg(t *testing.T) {
|
||||
loopbackIP := net.ParseIP("127.0.0.1")
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
ci *ctrld.ClientInfo
|
||||
want string
|
||||
}{
|
||||
{"tcp", &net.TCPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.10"}, "192.168.1.10:12345"},
|
||||
{"udp", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.11"}, "192.168.1.11:12345"},
|
||||
{"nil client info", &net.UDPAddr{IP: loopbackIP, Port: 12345}, nil, "127.0.0.1:12345"},
|
||||
{"empty ip", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{}, "127.0.0.1:12345"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
addr := spoofRemoteAddr(tc.addr, tc.ci)
|
||||
if addr.String() != tc.want {
|
||||
t.Errorf("unexpected result, want: %q, got: %q", tc.want, addr.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
@@ -26,18 +25,25 @@ var (
|
||||
cacheSize int
|
||||
cfg ctrld.Config
|
||||
verbose int
|
||||
silent bool
|
||||
cdUID string
|
||||
cdDev bool
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
setupRouter bool
|
||||
|
||||
rootLogger = zerolog.New(io.Discard)
|
||||
mainLog = rootLogger
|
||||
|
||||
cdUID string
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
mainLog = zerolog.New(io.Discard)
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
initCLI()
|
||||
initRouterCLI()
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeLogFilePath(logFilePath string) string {
|
||||
@@ -47,45 +53,65 @@ func normalizeLogFilePath(logFilePath string) string {
|
||||
if homedir != "" {
|
||||
return filepath.Join(homedir, logFilePath)
|
||||
}
|
||||
dir, _ := os.UserHomeDir()
|
||||
dir, _ := userHomeDir()
|
||||
if dir == "" {
|
||||
return logFilePath
|
||||
}
|
||||
return filepath.Join(dir, logFilePath)
|
||||
}
|
||||
|
||||
func initConsoleLogging() {
|
||||
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.StampMilli
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter)
|
||||
mainLog = mainLog.Output(multi).With().Timestamp().Logger()
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
case verbose == 1:
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
case verbose > 1:
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func initLogging() {
|
||||
writers := []io.Writer{io.Discard}
|
||||
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
|
||||
// Create parent directory if necessary.
|
||||
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
|
||||
log.Printf("failed to create log path: %v", err)
|
||||
mainLog.Error().Msgf("failed to create log path: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Backup old log file with .1 suffix.
|
||||
if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) {
|
||||
log.Printf("could not backup old log file: %v", err)
|
||||
mainLog.Error().Msgf("could not backup old log file: %v", err)
|
||||
}
|
||||
logFile, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_RDWR, os.FileMode(0o600))
|
||||
if err != nil {
|
||||
log.Printf("failed to create log file: %v", err)
|
||||
mainLog.Error().Msgf("failed to create log file: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
writers = append(writers, logFile)
|
||||
}
|
||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs
|
||||
consoleWriter := zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.StampMilli
|
||||
})
|
||||
writers = append(writers, consoleWriter)
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
mainLog = mainLog.Output(multi).With().Timestamp().Logger()
|
||||
// TODO: find a better way.
|
||||
ctrld.ProxyLog = mainLog
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
logLevel := cfg.Service.LogLevel
|
||||
if verbose > 1 {
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
return
|
||||
case verbose == 1:
|
||||
logLevel = "info"
|
||||
case verbose > 1:
|
||||
logLevel = "debug"
|
||||
}
|
||||
if logLevel == "" {
|
||||
|
||||
16
cmd/ctrld/main_test.go
Normal file
16
cmd/ctrld/main_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
var logOutput strings.Builder
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
mainLog = zerolog.New(&logOutput)
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
27
cmd/ctrld/netlink_linux.go
Normal file
27
cmd/ctrld/netlink_linux.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState() {
|
||||
ch := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
if err := netlink.LinkSubscribe(ch, done); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not subscribe link")
|
||||
return
|
||||
}
|
||||
for lu := range ch {
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
5
cmd/ctrld/netlink_others.go
Normal file
5
cmd/ctrld/netlink_others.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !linux
|
||||
|
||||
package main
|
||||
|
||||
func (p *prog) watchLinkState() {}
|
||||
@@ -112,7 +112,7 @@ func resetDNS(iface *net.Interface) (err error) {
|
||||
}
|
||||
|
||||
// TODO(cuonglm): handle DHCPv6 properly.
|
||||
if ctrldnet.SupportsIPv6() {
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
c := client6.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
if err != nil {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
var logf = func(format string, args ...any) {
|
||||
@@ -25,9 +26,14 @@ var errWindowsAddrInUse = syscall.Errno(0x2740)
|
||||
var svcConfig = &service.Config{
|
||||
Name: "ctrld",
|
||||
DisplayName: "Control-D Helper Service",
|
||||
Option: service.KeyValue{},
|
||||
}
|
||||
|
||||
type prog struct {
|
||||
mu sync.Mutex
|
||||
waitCh chan struct{}
|
||||
stopCh chan struct{}
|
||||
|
||||
cfg *ctrld.Config
|
||||
cache dnscache.Cacher
|
||||
}
|
||||
@@ -39,6 +45,8 @@ func (p *prog) Start(s service.Service) error {
|
||||
}
|
||||
|
||||
func (p *prog) run() {
|
||||
// Wait the caller to signal that we can do our logic.
|
||||
<-p.waitCh
|
||||
p.preRun()
|
||||
if p.cfg.Service.CacheEnable {
|
||||
cacher, err := dnscache.NewLRUCache(p.cfg.Service.CacheSize)
|
||||
@@ -66,13 +74,16 @@ func (p *prog) run() {
|
||||
uc.Init()
|
||||
if uc.BootstrapIP == "" {
|
||||
uc.SetupBootstrapIP()
|
||||
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Setting bootstrap IP for upstream.%s", n)
|
||||
mainLog.Info().Msgf("Bootstrap IPs for upstream.%s: %q", n, uc.BootstrapIPs())
|
||||
} else {
|
||||
mainLog.Info().Str("bootstrap_ip", uc.BootstrapIP).Msgf("Using bootstrap IP for upstream.%s", n)
|
||||
}
|
||||
uc.SetCertPool(rootCertPool)
|
||||
uc.SetupTransport()
|
||||
}
|
||||
|
||||
go p.watchLinkState()
|
||||
|
||||
for listenerNum := range p.cfg.Listener {
|
||||
p.cfg.Listener[listenerNum].Init()
|
||||
go func(listenerNum string) {
|
||||
@@ -80,8 +91,7 @@ func (p *prog) run() {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
upstreamConfig := p.cfg.Upstream[listenerNum]
|
||||
if upstreamConfig == nil {
|
||||
mainLog.Error().Msgf("missing upstream config for: [listener.%s]", listenerNum)
|
||||
return
|
||||
mainLog.Warn().Msgf("no default upstream for: [listener.%s]", listenerNum)
|
||||
}
|
||||
addr := net.JoinHostPort(listenerConfig.IP, strconv.Itoa(listenerConfig.Port))
|
||||
mainLog.Info().Msgf("Starting DNS server on listener.%s: %s", listenerNum, addr)
|
||||
@@ -106,7 +116,9 @@ func (p *prog) run() {
|
||||
} else {
|
||||
mainLog.Info().Msg("writing config file to: " + defaultConfigFile)
|
||||
}
|
||||
p.mu.Lock()
|
||||
p.cfg.Service.AllocateIP = true
|
||||
p.mu.Unlock()
|
||||
p.preRun()
|
||||
mainLog.Info().Msgf("Starting DNS server on listener.%s: %s", listenerNum, net.JoinHostPort(ip, strconv.Itoa(port)))
|
||||
if err := p.serveDNS(listenerNum); err != nil {
|
||||
@@ -127,11 +139,18 @@ func (p *prog) Stop(s service.Service) error {
|
||||
mainLog.Error().Err(err).Msg("de-allocate ip failed")
|
||||
return err
|
||||
}
|
||||
p.preStop()
|
||||
if err := router.Stop(); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("problem occurred while stopping router")
|
||||
}
|
||||
mainLog.Info().Msg("Service stopped")
|
||||
close(p.stopCh)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) allocateIP(ip string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if !p.cfg.Service.AllocateIP {
|
||||
return nil
|
||||
}
|
||||
@@ -139,6 +158,8 @@ func (p *prog) allocateIP(ip string) error {
|
||||
}
|
||||
|
||||
func (p *prog) deAllocateIP() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if !p.cfg.Service.AllocateIP {
|
||||
return nil
|
||||
}
|
||||
@@ -151,6 +172,15 @@ func (p *prog) deAllocateIP() error {
|
||||
}
|
||||
|
||||
func (p *prog) setDNS() {
|
||||
switch router.Name() {
|
||||
case router.DDWrt, router.OpenWrt, router.Ubios:
|
||||
// On router, ctrld run as a DNS forwarder, it does not have to change system DNS.
|
||||
// Except for:
|
||||
// + EdgeOS, which /etc/resolv.conf could be managed by vyatta_update_resolv.pl script.
|
||||
// + Merlin/Tomato, which has WAN DNS setup on boot for NTP.
|
||||
// + Synology, which /etc/resolv.conf is not configured to point to localhost.
|
||||
return
|
||||
}
|
||||
if cfg.Listener == nil || cfg.Listener["0"] == nil {
|
||||
return
|
||||
}
|
||||
@@ -179,6 +209,11 @@ func (p *prog) setDNS() {
|
||||
}
|
||||
|
||||
func (p *prog) resetDNS() {
|
||||
switch router.Name() {
|
||||
case router.DDWrt, router.OpenWrt, router.Ubios:
|
||||
// See comment in p.setDNS method.
|
||||
return
|
||||
}
|
||||
if iface == "" {
|
||||
return
|
||||
}
|
||||
|
||||
23
cmd/ctrld/prog_darwin.go
Normal file
23
cmd/ctrld/prog_darwin.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func (p *prog) preRun() {
|
||||
if !service.Interactive() {
|
||||
p.setDNS()
|
||||
}
|
||||
}
|
||||
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
func (p *prog) preStop() {
|
||||
if !service.Interactive() {
|
||||
p.resetDNS()
|
||||
}
|
||||
}
|
||||
@@ -18,3 +18,5 @@ func setDependencies(svc *service.Config) {
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {}
|
||||
|
||||
func (p *prog) preStop() {}
|
||||
|
||||
@@ -2,6 +2,8 @@ package main
|
||||
|
||||
import (
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func (p *prog) preRun() {
|
||||
@@ -17,8 +19,17 @@ func setDependencies(svc *service.Config) {
|
||||
"Wants=NetworkManager-wait-online.service",
|
||||
"After=NetworkManager-wait-online.service",
|
||||
}
|
||||
// On EdeOS, ctrld needs to start after vyatta-dhcpd, so it can read leases file.
|
||||
if router.Name() == router.EdgeOS {
|
||||
svc.Dependencies = append(svc.Dependencies, "Wants=vyatta-dhcpd.service")
|
||||
svc.Dependencies = append(svc.Dependencies, "After=vyatta-dhcpd.service")
|
||||
svc.Dependencies = append(svc.Dependencies, "Wants=dnsmasq.service")
|
||||
svc.Dependencies = append(svc.Dependencies, "After=dnsmasq.service")
|
||||
}
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
func (p *prog) preStop() {}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !freebsd
|
||||
//go:build !linux && !freebsd && !darwin
|
||||
|
||||
package main
|
||||
|
||||
@@ -12,3 +12,5 @@ func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
func (p *prog) preStop() {}
|
||||
|
||||
@@ -1,18 +1,45 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"bytes"
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func stderrMsg(msg string) {
|
||||
_, _ = fmt.Fprintln(os.Stderr, msg)
|
||||
func newService(s service.Service) service.Service {
|
||||
// TODO: unify for other SysV system.
|
||||
switch {
|
||||
case router.IsGLiNet(), router.IsOldOpenwrt():
|
||||
return &sysV{s}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func stdoutMsg(msg string) {
|
||||
_, _ = fmt.Fprintln(os.Stdout, msg)
|
||||
// sysV wraps a service.Service, and provide start/stop/status command
|
||||
// base on "/etc/init.d/<service_name>".
|
||||
//
|
||||
// Use this on system wherer "service" command is not available, like GL.iNET router.
|
||||
type sysV struct {
|
||||
service.Service
|
||||
}
|
||||
|
||||
func (s *sysV) Start() error {
|
||||
_, err := exec.Command("/etc/init.d/ctrld", "start").CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *sysV) Stop() error {
|
||||
_, err := exec.Command("/etc/init.d/ctrld", "stop").CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *sysV) Status() (service.Status, error) {
|
||||
return unixSystemVServiceStatus()
|
||||
}
|
||||
|
||||
type task struct {
|
||||
@@ -21,25 +48,48 @@ type task struct {
|
||||
}
|
||||
|
||||
func doTasks(tasks []task) bool {
|
||||
var prevErr error
|
||||
for _, task := range tasks {
|
||||
if err := task.f(); err != nil {
|
||||
if task.abortOnError {
|
||||
stderrMsg(err.Error())
|
||||
mainLog.Error().Msg(errors.Join(prevErr, err).Error())
|
||||
return false
|
||||
}
|
||||
prevErr = err
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func checkHasElevatedPrivilege(cmd *cobra.Command, args []string) {
|
||||
func checkHasElevatedPrivilege() {
|
||||
ok, err := hasElevatedPrivilege()
|
||||
if err != nil {
|
||||
fmt.Printf("could not detect user privilege: %v", err)
|
||||
mainLog.Error().Msgf("could not detect user privilege: %v", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
fmt.Println("Please relaunch process with admin/root privilege.")
|
||||
mainLog.Error().Msg("Please relaunch process with admin/root privilege.")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func serviceStatus(s service.Service) (service.Status, error) {
|
||||
status, err := s.Status()
|
||||
if err != nil && service.Platform() == "unix-systemv" {
|
||||
return unixSystemVServiceStatus()
|
||||
}
|
||||
return status, err
|
||||
}
|
||||
|
||||
func unixSystemVServiceStatus() (service.Status, error) {
|
||||
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
switch string(bytes.TrimSpace(out)) {
|
||||
case "running":
|
||||
return service.StatusRunning, nil
|
||||
default:
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
|
||||
467
config.go
467
config.go
@@ -2,41 +2,75 @@ package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/sync/singleflight"
|
||||
"tailscale.com/logtail/backoff"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnsrcode"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
// SetConfigName set the config name that ctrld will look for.
|
||||
func SetConfigName(v *viper.Viper, name string) {
|
||||
v.SetConfigName(name)
|
||||
// IpStackBoth ...
|
||||
const (
|
||||
// IpStackBoth indicates that ctrld will use either ipv4 or ipv6 for connecting to upstream,
|
||||
// depending on which stack is available when receiving the DNS query.
|
||||
IpStackBoth = "both"
|
||||
// IpStackV4 indicates that ctrld will use only ipv4 for connecting to upstream.
|
||||
IpStackV4 = "v4"
|
||||
// IpStackV6 indicates that ctrld will use only ipv6 for connecting to upstream.
|
||||
IpStackV6 = "v6"
|
||||
// IpStackSplit indicates that ctrld will use either ipv4 or ipv6 for connecting to upstream,
|
||||
// depending on the record type of the DNS query.
|
||||
IpStackSplit = "split"
|
||||
|
||||
controlDComDomain = "controld.com"
|
||||
controlDNetDomain = "controld.net"
|
||||
controlDDevDomain = "controld.dev"
|
||||
)
|
||||
|
||||
var (
|
||||
controldParentDomains = []string{controlDComDomain, controlDNetDomain, controlDDevDomain}
|
||||
controldVerifiedDomain = map[string]string{
|
||||
controlDComDomain: "verify.controld.com",
|
||||
controlDDevDomain: "verify.controld.dev",
|
||||
}
|
||||
)
|
||||
|
||||
// SetConfigName set the config name that ctrld will look for.
|
||||
// DEPRECATED: use SetConfigNameWithPath instead.
|
||||
func SetConfigName(v *viper.Viper, name string) {
|
||||
configPath := "$HOME"
|
||||
// viper has its own way to get user home directory: https://github.com/spf13/viper/blob/v1.14.0/util.go#L134
|
||||
// To be consistent, we prefer os.UserHomeDir instead.
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
configPath = homeDir
|
||||
}
|
||||
SetConfigNameWithPath(v, name, configPath)
|
||||
}
|
||||
|
||||
// SetConfigNameWithPath set the config path and name that ctrld will look for.
|
||||
func SetConfigNameWithPath(v *viper.Viper, name, configPath string) {
|
||||
v.SetConfigName(name)
|
||||
v.AddConfigPath(configPath)
|
||||
v.AddConfigPath(".")
|
||||
}
|
||||
|
||||
// InitConfig initializes default config values for given *viper.Viper instance.
|
||||
func InitConfig(v *viper.Viper, name string) {
|
||||
SetConfigName(v, name)
|
||||
|
||||
v.SetDefault("listener", map[string]*ListenerConfig{
|
||||
"0": {
|
||||
IP: "127.0.0.1",
|
||||
@@ -75,6 +109,17 @@ type Config struct {
|
||||
Upstream map[string]*UpstreamConfig `mapstructure:"upstream" toml:"upstream" validate:"min=1,dive"`
|
||||
}
|
||||
|
||||
// HasUpstreamSendClientInfo reports whether the config has any upstream
|
||||
// is configured to send client info to Control D DNS server.
|
||||
func (c *Config) HasUpstreamSendClientInfo() bool {
|
||||
for _, uc := range c.Upstream {
|
||||
if uc.UpstreamSendClientInfo() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ServiceConfig specifies the general ctrld config.
|
||||
type ServiceConfig struct {
|
||||
LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"`
|
||||
@@ -96,24 +141,36 @@ type NetworkConfig struct {
|
||||
|
||||
// UpstreamConfig specifies configuration for upstreams that ctrld will forward requests to.
|
||||
type UpstreamConfig struct {
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
|
||||
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"`
|
||||
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
|
||||
Domain string `mapstructure:"-" toml:"-"`
|
||||
Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"`
|
||||
transport *http.Transport `mapstructure:"-" toml:"-"`
|
||||
http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"`
|
||||
Name string `mapstructure:"name" toml:"name,omitempty"`
|
||||
Type string `mapstructure:"type" toml:"type,omitempty" validate:"oneof=doh doh3 dot doq os legacy"`
|
||||
Endpoint string `mapstructure:"endpoint" toml:"endpoint,omitempty" validate:"required_unless=Type os"`
|
||||
BootstrapIP string `mapstructure:"bootstrap_ip" toml:"bootstrap_ip,omitempty"`
|
||||
Domain string `mapstructure:"-" toml:"-"`
|
||||
IPStack string `mapstructure:"ip_stack" toml:"ip_stack,omitempty" validate:"ipstack"`
|
||||
Timeout int `mapstructure:"timeout" toml:"timeout,omitempty" validate:"gte=0"`
|
||||
// The caller should not access this field directly.
|
||||
// Use UpstreamSendClientInfo instead.
|
||||
SendClientInfo *bool `mapstructure:"send_client_info" toml:"send_client_info,omitempty"`
|
||||
|
||||
g singleflight.Group
|
||||
bootstrapIPs []string
|
||||
nextBootstrapIP atomic.Uint32
|
||||
g singleflight.Group
|
||||
mu sync.Mutex
|
||||
bootstrapIPs []string
|
||||
bootstrapIPs4 []string
|
||||
bootstrapIPs6 []string
|
||||
transport *http.Transport
|
||||
transport4 *http.Transport
|
||||
transport6 *http.Transport
|
||||
http3RoundTripper http.RoundTripper
|
||||
http3RoundTripper4 http.RoundTripper
|
||||
http3RoundTripper6 http.RoundTripper
|
||||
certPool *x509.CertPool
|
||||
u *url.URL
|
||||
}
|
||||
|
||||
// ListenerConfig specifies the networks configuration that ctrld will run on.
|
||||
type ListenerConfig struct {
|
||||
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"ip"`
|
||||
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gt=0"`
|
||||
IP string `mapstructure:"ip" toml:"ip,omitempty" validate:"iporempty"`
|
||||
Port int `mapstructure:"port" toml:"port,omitempty" validate:"gte=0"`
|
||||
Restricted bool `mapstructure:"restricted" toml:"restricted,omitempty"`
|
||||
Policy *ListenerPolicyConfig `mapstructure:"policy" toml:"policy,omitempty"`
|
||||
}
|
||||
@@ -136,87 +193,105 @@ type Rule map[string][]string
|
||||
func (uc *UpstreamConfig) Init() {
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
uc.Domain = u.Host
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
uc.u = u
|
||||
}
|
||||
}
|
||||
if uc.Domain != "" {
|
||||
return
|
||||
if uc.Domain == "" {
|
||||
if !strings.Contains(uc.Endpoint, ":") {
|
||||
uc.Domain = uc.Endpoint
|
||||
uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type))
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(uc.Endpoint)
|
||||
uc.Domain = host
|
||||
if net.ParseIP(uc.Domain) != nil {
|
||||
uc.BootstrapIP = uc.Domain
|
||||
}
|
||||
}
|
||||
if uc.IPStack == "" {
|
||||
if uc.isControlD() {
|
||||
uc.IPStack = IpStackSplit
|
||||
} else {
|
||||
uc.IPStack = IpStackBoth
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.Contains(uc.Endpoint, ":") {
|
||||
uc.Domain = uc.Endpoint
|
||||
uc.Endpoint = net.JoinHostPort(uc.Endpoint, defaultPortFor(uc.Type))
|
||||
// VerifyDomain returns the domain name that could be resolved by the upstream endpoint.
|
||||
// It returns empty for non-ControlD upstream endpoint.
|
||||
func (uc *UpstreamConfig) VerifyDomain() string {
|
||||
domain := uc.Domain
|
||||
if domain == "" {
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
domain = u.Hostname()
|
||||
}
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(uc.Endpoint)
|
||||
uc.Domain = host
|
||||
if net.ParseIP(uc.Domain) != nil {
|
||||
uc.BootstrapIP = uc.Domain
|
||||
for _, parent := range controldParentDomains {
|
||||
if dns.IsSubDomain(parent, domain) {
|
||||
return controldVerifiedDomain[parent]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// UpstreamSendClientInfo reports whether the upstream is
|
||||
// configured to send client info to Control D DNS server.
|
||||
//
|
||||
// Client info includes:
|
||||
// - MAC
|
||||
// - Lan IP
|
||||
// - Hostname
|
||||
func (uc *UpstreamConfig) UpstreamSendClientInfo() bool {
|
||||
if uc.SendClientInfo != nil && !(*uc.SendClientInfo) {
|
||||
return false
|
||||
}
|
||||
if uc.SendClientInfo == nil {
|
||||
return true
|
||||
}
|
||||
switch uc.Type {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
if uc.isControlD() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// BootstrapIPs returns the bootstrap IPs list of upstreams.
|
||||
func (uc *UpstreamConfig) BootstrapIPs() []string {
|
||||
return uc.bootstrapIPs
|
||||
}
|
||||
|
||||
// SetCertPool sets the system cert pool used for TLS connections.
|
||||
func (uc *UpstreamConfig) SetCertPool(cp *x509.CertPool) {
|
||||
uc.certPool = cp
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) SetupBootstrapIP() {
|
||||
bootstrapIP := func(record dns.RR) string {
|
||||
switch ar := record.(type) {
|
||||
case *dns.A:
|
||||
return ar.A.String()
|
||||
case *dns.AAAA:
|
||||
return ar.AAAA.String()
|
||||
uc.setupBootstrapIP(true)
|
||||
}
|
||||
|
||||
// SetupBootstrapIP manually find all available IPs of the upstream.
|
||||
// The first usable IP will be used as bootstrap IP of the upstream.
|
||||
func (uc *UpstreamConfig) setupBootstrapIP(withBootstrapDNS bool) {
|
||||
b := backoff.NewBackoff("setupBootstrapIP", func(format string, args ...any) {}, 2*time.Second)
|
||||
for {
|
||||
uc.bootstrapIPs = lookupIP(uc.Domain, uc.Timeout, withBootstrapDNS)
|
||||
if len(uc.bootstrapIPs) > 0 {
|
||||
break
|
||||
}
|
||||
return ""
|
||||
ProxyLog.Warn().Msg("could not resolve bootstrap IPs, retrying...")
|
||||
b.BackOff(context.Background(), errors.New("no bootstrap IPs"))
|
||||
}
|
||||
|
||||
resolver := &osResolver{nameservers: availableNameservers()}
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
||||
ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", uc.Domain, resolver.nameservers)
|
||||
timeoutMs := 2000
|
||||
if uc.Timeout > 0 && uc.Timeout < timeoutMs {
|
||||
timeoutMs = uc.Timeout
|
||||
}
|
||||
do := func(dnsType uint16) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(uc.Domain+".", dnsType)
|
||||
m.RecursionDesired = true
|
||||
|
||||
r, err := resolver.Resolve(ctx, m)
|
||||
if err != nil {
|
||||
ProxyLog.Error().Err(err).Str("type", dns.TypeToString[dnsType]).Msgf("could not resolve domain %s for upstream", uc.Domain)
|
||||
return
|
||||
for _, ip := range uc.bootstrapIPs {
|
||||
if ctrldnet.IsIPv6(ip) {
|
||||
uc.bootstrapIPs6 = append(uc.bootstrapIPs6, ip)
|
||||
} else {
|
||||
uc.bootstrapIPs4 = append(uc.bootstrapIPs4, ip)
|
||||
}
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
ProxyLog.Error().Msgf("could not resolve domain return code: %d, upstream", r.Rcode)
|
||||
return
|
||||
}
|
||||
if len(r.Answer) == 0 {
|
||||
ProxyLog.Error().Msg("no answer from bootstrap DNS server")
|
||||
return
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
ip := bootstrapIP(a)
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Storing the ip to uc.bootstrapIPs list, so it can be selected later
|
||||
// when retrying failed request due to network stack changed.
|
||||
uc.bootstrapIPs = append(uc.bootstrapIPs, ip)
|
||||
if uc.BootstrapIP == "" {
|
||||
// Remember what's the current IP in bootstrap IPs list,
|
||||
// so we can select next one upon re-bootstrapping.
|
||||
uc.nextBootstrapIP.Add(1)
|
||||
|
||||
// If this is an ipv6, and ipv6 is not available, don't use it as bootstrap ip.
|
||||
if !ctrldnet.SupportsIPv6() && ctrldnet.IsIPv6(ip) {
|
||||
continue
|
||||
}
|
||||
uc.BootstrapIP = ip
|
||||
}
|
||||
}
|
||||
}
|
||||
// Find all A, AAAA records of the upstream.
|
||||
for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} {
|
||||
do(dnsType)
|
||||
}
|
||||
ProxyLog.Debug().Msgf("Bootstrap IPs: %v", uc.bootstrapIPs)
|
||||
}
|
||||
@@ -228,30 +303,8 @@ func (uc *UpstreamConfig) ReBootstrap() {
|
||||
default:
|
||||
return
|
||||
}
|
||||
_, _, _ = uc.g.Do("rebootstrap", func() (any, error) {
|
||||
_, _, _ = uc.g.Do("ReBootstrap", func() (any, error) {
|
||||
ProxyLog.Debug().Msg("re-bootstrapping upstream ip")
|
||||
n := uint32(len(uc.bootstrapIPs))
|
||||
|
||||
timeoutMs := 1000
|
||||
if uc.Timeout > 0 && uc.Timeout < timeoutMs {
|
||||
timeoutMs = uc.Timeout
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
hasIPv6 := ctrldnet.IPv6Available(ctx)
|
||||
// Only attempt n times, because if there's no usable ip,
|
||||
// the bootstrap ip will be kept as-is.
|
||||
for i := uint32(0); i < n; i++ {
|
||||
// Select the next ip in bootstrap ip list.
|
||||
next := uc.nextBootstrapIP.Add(1)
|
||||
ip := uc.bootstrapIPs[(next-1)%n]
|
||||
if !hasIPv6 && ctrldnet.IsIPv6(ip) {
|
||||
continue
|
||||
}
|
||||
uc.BootstrapIP = ip
|
||||
break
|
||||
}
|
||||
uc.setupTransportWithoutPingUpstream()
|
||||
return true, nil
|
||||
})
|
||||
@@ -279,31 +332,63 @@ func (uc *UpstreamConfig) SetupTransport() {
|
||||
|
||||
func (uc *UpstreamConfig) setupDOHTransport() {
|
||||
uc.setupDOHTransportWithoutPingUpstream()
|
||||
uc.pingUpstream()
|
||||
go uc.pingUpstream()
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
|
||||
uc.transport = http.DefaultTransport.(*http.Transport).Clone()
|
||||
uc.transport.IdleConnTimeout = 5 * time.Second
|
||||
func (uc *UpstreamConfig) newDOHTransport(addrs []string) *http.Transport {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.IdleConnTimeout = 5 * time.Second
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
|
||||
dialerTimeoutMs := 2000
|
||||
if uc.Timeout > 0 && uc.Timeout < dialerTimeoutMs {
|
||||
dialerTimeoutMs = uc.Timeout
|
||||
}
|
||||
dialerTimeout := time.Duration(dialerTimeoutMs) * time.Millisecond
|
||||
uc.transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: dialerTimeout,
|
||||
KeepAlive: dialerTimeout,
|
||||
}
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
if uc.BootstrapIP != "" {
|
||||
if _, port, _ := net.SplitHostPort(addr); port != "" {
|
||||
addr = net.JoinHostPort(uc.BootstrapIP, port)
|
||||
}
|
||||
dialer := net.Dialer{Timeout: dialerTimeout, KeepAlive: dialerTimeout}
|
||||
addr := net.JoinHostPort(uc.BootstrapIP, port)
|
||||
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr)
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", addr)
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
pd := &ctrldnet.ParallelDialer{}
|
||||
pd.Timeout = dialerTimeout
|
||||
pd.KeepAlive = dialerTimeout
|
||||
dialAddrs := make([]string, len(addrs))
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
conn, err := pd.DialContext(ctx, network, dialAddrs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
Log(ctx, ProxyLog.Debug(), "sending doh request to: %s", conn.RemoteAddr())
|
||||
return conn, nil
|
||||
}
|
||||
return transport
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) setupDOHTransportWithoutPingUpstream() {
|
||||
uc.mu.Lock()
|
||||
defer uc.mu.Unlock()
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.transport4 = uc.newDOHTransport(uc.bootstrapIPs4)
|
||||
if hasIPv6() {
|
||||
uc.transport6 = uc.newDOHTransport(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.transport6 = uc.transport4
|
||||
}
|
||||
|
||||
uc.transport = uc.newDOHTransport(uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,6 +407,82 @@ func (uc *UpstreamConfig) pingUpstream() {
|
||||
_, _ = dnsResolver.Resolve(ctx, msg)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) isControlD() bool {
|
||||
domain := uc.Domain
|
||||
if domain == "" {
|
||||
if u, err := url.Parse(uc.Endpoint); err == nil {
|
||||
domain = u.Hostname()
|
||||
}
|
||||
}
|
||||
for _, parent := range controldParentDomains {
|
||||
if dns.IsSubDomain(parent, domain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) dohTransport(dnsType uint16) http.RoundTripper {
|
||||
uc.mu.Lock()
|
||||
defer uc.mu.Unlock()
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
return uc.transport
|
||||
case IpStackSplit:
|
||||
switch dnsType {
|
||||
case dns.TypeA:
|
||||
return uc.transport4
|
||||
default:
|
||||
return uc.transport6
|
||||
}
|
||||
}
|
||||
return uc.transport
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) bootstrapIPForDNSType(dnsType uint16) string {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth:
|
||||
return pick(uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
return pick(uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
return pick(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
switch dnsType {
|
||||
case dns.TypeA:
|
||||
return pick(uc.bootstrapIPs4)
|
||||
default:
|
||||
if hasIPv6() {
|
||||
return pick(uc.bootstrapIPs6)
|
||||
}
|
||||
return pick(uc.bootstrapIPs4)
|
||||
}
|
||||
}
|
||||
return pick(uc.bootstrapIPs)
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) netForDNSType(dnsType uint16) (string, string) {
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth:
|
||||
return "tcp-tls", "udp"
|
||||
case IpStackV4:
|
||||
return "tcp4-tls", "udp4"
|
||||
case IpStackV6:
|
||||
return "tcp6-tls", "udp6"
|
||||
case IpStackSplit:
|
||||
switch dnsType {
|
||||
case dns.TypeA:
|
||||
return "tcp4-tls", "udp4"
|
||||
default:
|
||||
if hasIPv6() {
|
||||
return "tcp6-tls", "udp6"
|
||||
}
|
||||
return "tcp4-tls", "udp4"
|
||||
}
|
||||
}
|
||||
return "tcp-tls", "udp"
|
||||
}
|
||||
|
||||
// Init initialized necessary values for an ListenerConfig.
|
||||
func (lc *ListenerConfig) Init() {
|
||||
if lc.Policy != nil {
|
||||
@@ -335,6 +496,8 @@ func (lc *ListenerConfig) Init() {
|
||||
// ValidateConfig validates the given config.
|
||||
func ValidateConfig(validate *validator.Validate, cfg *Config) error {
|
||||
_ = validate.RegisterValidation("dnsrcode", validateDnsRcode)
|
||||
_ = validate.RegisterValidation("ipstack", validateIpStack)
|
||||
_ = validate.RegisterValidation("iporempty", validateIpOrEmpty)
|
||||
return validate.Struct(cfg)
|
||||
}
|
||||
|
||||
@@ -342,6 +505,23 @@ func validateDnsRcode(fl validator.FieldLevel) bool {
|
||||
return dnsrcode.FromString(fl.Field().String()) != -1
|
||||
}
|
||||
|
||||
func validateIpStack(fl validator.FieldLevel) bool {
|
||||
switch fl.Field().String() {
|
||||
case IpStackBoth, IpStackV4, IpStackV6, IpStackSplit, "":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func validateIpOrEmpty(fl validator.FieldLevel) bool {
|
||||
val := fl.Field().String()
|
||||
if val == "" {
|
||||
return true
|
||||
}
|
||||
return net.ParseIP(val) != nil
|
||||
}
|
||||
|
||||
func defaultPortFor(typ string) string {
|
||||
switch typ {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
@@ -354,17 +534,30 @@ func defaultPortFor(typ string) string {
|
||||
return "53"
|
||||
}
|
||||
|
||||
func availableNameservers() []string {
|
||||
nss := nameservers()
|
||||
n := 0
|
||||
for _, ns := range nss {
|
||||
ip, _, _ := net.SplitHostPort(ns)
|
||||
// skipping invalid entry or ipv6 nameserver if ipv6 not available.
|
||||
if ip == "" || (ctrldnet.IsIPv6(ip) && !ctrldnet.SupportsIPv6()) {
|
||||
continue
|
||||
}
|
||||
nss[n] = ns
|
||||
n++
|
||||
// ResolverTypeFromEndpoint tries guessing the resolver type with a given endpoint
|
||||
// using following rules:
|
||||
//
|
||||
// - If endpoint is an IP address -> ResolverTypeLegacy
|
||||
// - If endpoint starts with "https://" -> ResolverTypeDOH
|
||||
// - If endpoint starts with "quic://" -> ResolverTypeDOQ
|
||||
// - For anything else -> ResolverTypeDOT
|
||||
func ResolverTypeFromEndpoint(endpoint string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(endpoint, "https://"):
|
||||
return ResolverTypeDOH
|
||||
case strings.HasPrefix(endpoint, "quic://"):
|
||||
return ResolverTypeDOQ
|
||||
}
|
||||
return nss[:n]
|
||||
host := endpoint
|
||||
if strings.Contains(endpoint, ":") {
|
||||
host, _, _ = net.SplitHostPort(host)
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
return ResolverTypeLegacy
|
||||
}
|
||||
return ResolverTypeDOT
|
||||
}
|
||||
|
||||
func pick(s []string) string {
|
||||
return s[rand.Intn(len(s))]
|
||||
}
|
||||
|
||||
228
config_internal_test.go
Normal file
228
config_internal_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) {
|
||||
uc := &UpstreamConfig{
|
||||
Name: "test",
|
||||
Type: ResolverTypeDOH,
|
||||
Endpoint: "https://freedns.controld.com/p2",
|
||||
Timeout: 5000,
|
||||
}
|
||||
uc.Init()
|
||||
uc.setupBootstrapIP(false)
|
||||
if len(uc.bootstrapIPs) == 0 {
|
||||
t.Log(nameservers())
|
||||
t.Fatal("could not bootstrap ip without bootstrap DNS")
|
||||
}
|
||||
t.Log(uc)
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_Init(t *testing.T) {
|
||||
u1, _ := url.Parse("https://example.com")
|
||||
u2, _ := url.Parse("https://example.com?k=v")
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
expected *UpstreamConfig
|
||||
}{
|
||||
{
|
||||
"doh+doh3",
|
||||
&UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: "doh",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: "doh",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u1,
|
||||
},
|
||||
},
|
||||
{
|
||||
"doh+doh3 with query param",
|
||||
&UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: "doh",
|
||||
Endpoint: "https://example.com?k=v",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: "doh",
|
||||
Endpoint: "https://example.com?k=v",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
u: u2,
|
||||
},
|
||||
},
|
||||
{
|
||||
"dot+doq",
|
||||
&UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com:8853",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com:8853",
|
||||
BootstrapIP: "",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackSplit,
|
||||
},
|
||||
},
|
||||
{
|
||||
"dot+doq without port",
|
||||
&UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackSplit,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com:853",
|
||||
BootstrapIP: "",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackSplit,
|
||||
},
|
||||
},
|
||||
{
|
||||
"legacy",
|
||||
&UpstreamConfig{
|
||||
Name: "legacy",
|
||||
Type: "legacy",
|
||||
Endpoint: "1.2.3.4:53",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "legacy",
|
||||
Type: "legacy",
|
||||
Endpoint: "1.2.3.4:53",
|
||||
BootstrapIP: "1.2.3.4",
|
||||
Domain: "1.2.3.4",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"legacy without port",
|
||||
&UpstreamConfig{
|
||||
Name: "legacy",
|
||||
Type: "legacy",
|
||||
Endpoint: "1.2.3.4",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "legacy",
|
||||
Type: "legacy",
|
||||
Endpoint: "1.2.3.4:53",
|
||||
BootstrapIP: "1.2.3.4",
|
||||
Domain: "1.2.3.4",
|
||||
Timeout: 0,
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
},
|
||||
{
|
||||
"doh+doh3 with send client info set",
|
||||
&UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: "doh",
|
||||
Endpoint: "https://example.com?k=v",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
SendClientInfo: ptrBool(false),
|
||||
IPStack: IpStackBoth,
|
||||
},
|
||||
&UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: "doh",
|
||||
Endpoint: "https://example.com?k=v",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
SendClientInfo: ptrBool(false),
|
||||
IPStack: IpStackBoth,
|
||||
u: u2,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
assert.Equal(t, tc.expected, tc.uc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_VerifyDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *UpstreamConfig
|
||||
verifyDomain string
|
||||
}{
|
||||
{
|
||||
controlDComDomain,
|
||||
&UpstreamConfig{Endpoint: "https://freedns.controld.com/p2"},
|
||||
controldVerifiedDomain[controlDComDomain],
|
||||
},
|
||||
{
|
||||
controlDDevDomain,
|
||||
&UpstreamConfig{Endpoint: "https://freedns.controld.dev/p2"},
|
||||
controldVerifiedDomain[controlDDevDomain],
|
||||
},
|
||||
{
|
||||
"non-ControlD upstream",
|
||||
&UpstreamConfig{Endpoint: "https://dns.google/dns-query"},
|
||||
"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := tc.uc.VerifyDomain(); got != tc.verifyDomain {
|
||||
t.Errorf("unexpected verify domain, want: %q, got: %q", tc.verifyDomain, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func ptrBool(b bool) *bool {
|
||||
return &b
|
||||
}
|
||||
153
config_quic.go
153
config_quic.go
@@ -5,40 +5,153 @@ package ctrld
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {
|
||||
uc.setupDOH3TransportWithoutPingUpstream()
|
||||
uc.pingUpstream()
|
||||
go uc.pingUpstream()
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) newDOH3Transport(addrs []string) http.RoundTripper {
|
||||
rt := &http3.RoundTripper{}
|
||||
rt.TLSClientConfig = &tls.Config{RootCAs: uc.certPool}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
domain := addr
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
if uc.BootstrapIP != "" {
|
||||
addr = net.JoinHostPort(uc.BootstrapIP, port)
|
||||
ProxyLog.Debug().Msgf("sending doh3 request to: %s", addr)
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
|
||||
}
|
||||
dialAddrs := make([]string, len(addrs))
|
||||
for i := range addrs {
|
||||
dialAddrs[i] = net.JoinHostPort(addrs[i], port)
|
||||
}
|
||||
pd := &quicParallelDialer{}
|
||||
conn, err := pd.Dial(ctx, domain, dialAddrs, tlsCfg, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ProxyLog.Debug().Msgf("sending doh3 request to: %s", conn.RemoteAddr())
|
||||
return conn, err
|
||||
}
|
||||
return rt
|
||||
}
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {
|
||||
rt := &http3.RoundTripper{}
|
||||
rt.Dial = func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
host := addr
|
||||
ProxyLog.Debug().Msgf("debug dial context D0H3 %s - %s", addr, bootstrapDNS)
|
||||
// if we have a bootstrap ip set, use it to avoid DNS lookup
|
||||
if uc.BootstrapIP != "" {
|
||||
if _, port, _ := net.SplitHostPort(addr); port != "" {
|
||||
addr = net.JoinHostPort(uc.BootstrapIP, port)
|
||||
}
|
||||
ProxyLog.Debug().Msgf("sending doh3 request to: %s", addr)
|
||||
}
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
uc.mu.Lock()
|
||||
defer uc.mu.Unlock()
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, "":
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
case IpStackV4:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
case IpStackV6:
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
case IpStackSplit:
|
||||
uc.http3RoundTripper4 = uc.newDOH3Transport(uc.bootstrapIPs4)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
uc.http3RoundTripper6 = uc.newDOH3Transport(uc.bootstrapIPs6)
|
||||
} else {
|
||||
uc.http3RoundTripper6 = uc.http3RoundTripper4
|
||||
}
|
||||
uc.http3RoundTripper = uc.newDOH3Transport(uc.bootstrapIPs)
|
||||
}
|
||||
}
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper {
|
||||
uc.mu.Lock()
|
||||
defer uc.mu.Unlock()
|
||||
switch uc.IPStack {
|
||||
case IpStackBoth, IpStackV4, IpStackV6:
|
||||
return uc.http3RoundTripper
|
||||
case IpStackSplit:
|
||||
switch dnsType {
|
||||
case dns.TypeA:
|
||||
return uc.http3RoundTripper4
|
||||
default:
|
||||
return uc.http3RoundTripper6
|
||||
}
|
||||
return quic.DialEarlyContext(ctx, udpConn, remoteAddr, host, tlsCfg, cfg)
|
||||
}
|
||||
return uc.http3RoundTripper
|
||||
}
|
||||
|
||||
// Putting the code for quic parallel dialer here:
|
||||
//
|
||||
// - quic dialer is different with net.Dialer
|
||||
// - simplification for quic free version
|
||||
type parallelDialerResult struct {
|
||||
conn quic.EarlyConnection
|
||||
err error
|
||||
}
|
||||
|
||||
type quicParallelDialer struct{}
|
||||
|
||||
// Dial performs parallel dialing to the given address list.
|
||||
func (d *quicParallelDialer) Dial(ctx context.Context, domain string, addrs []string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
ch := make(chan *parallelDialerResult, len(addrs))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(addrs))
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
udpConn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uc.http3RoundTripper = rt
|
||||
for _, addr := range addrs {
|
||||
go func(addr string) {
|
||||
defer wg.Done()
|
||||
remoteAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
ch <- ¶llelDialerResult{conn: nil, err: err}
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := quic.DialEarlyContext(ctx, udpConn, remoteAddr, domain, tlsCfg, cfg)
|
||||
ch <- ¶llelDialerResult{conn: conn, err: err}
|
||||
}(addr)
|
||||
}
|
||||
|
||||
errs := make([]error, 0, len(addrs))
|
||||
for res := range ch {
|
||||
if res.err == nil {
|
||||
cancel()
|
||||
return res.conn, res.err
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
package ctrld
|
||||
|
||||
import "net/http"
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3Transport() {}
|
||||
|
||||
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {}
|
||||
func (uc *UpstreamConfig) setupDOH3TransportWithoutPingUpstream() {}
|
||||
func (uc *UpstreamConfig) doh3Transport(dnsType uint16) http.RoundTripper { return nil }
|
||||
|
||||
128
config_test.go
128
config_test.go
@@ -24,10 +24,12 @@ func TestLoadConfig(t *testing.T) {
|
||||
assert.Contains(t, cfg.Network, "0")
|
||||
assert.Contains(t, cfg.Network, "1")
|
||||
|
||||
assert.Len(t, cfg.Upstream, 3)
|
||||
assert.Len(t, cfg.Upstream, 4)
|
||||
assert.Contains(t, cfg.Upstream, "0")
|
||||
assert.Contains(t, cfg.Upstream, "1")
|
||||
assert.Contains(t, cfg.Upstream, "2")
|
||||
assert.Contains(t, cfg.Upstream, "3")
|
||||
assert.NotNil(t, cfg.Upstream["3"].SendClientInfo)
|
||||
|
||||
assert.Len(t, cfg.Listener, 2)
|
||||
assert.Contains(t, cfg.Listener, "0")
|
||||
@@ -42,6 +44,8 @@ func TestLoadConfig(t *testing.T) {
|
||||
assert.Len(t, cfg.Listener["0"].Policy.Rules, 2)
|
||||
assert.Contains(t, cfg.Listener["0"].Policy.Rules[0], "*.ru")
|
||||
assert.Contains(t, cfg.Listener["0"].Policy.Rules[1], "*.local.host")
|
||||
|
||||
assert.True(t, cfg.HasUpstreamSendClientInfo())
|
||||
}
|
||||
|
||||
func TestLoadDefaultConfig(t *testing.T) {
|
||||
@@ -61,6 +65,7 @@ func TestConfigValidation(t *testing.T) {
|
||||
{"invalid Config", &ctrld.Config{}, true},
|
||||
{"default Config", defaultConfig(t), false},
|
||||
{"sample Config", testhelper.SampleConfig(t), false},
|
||||
{"empty listener IP", emptyListenerIP(t), false},
|
||||
{"invalid cidr", invalidNetworkConfig(t), true},
|
||||
{"invalid upstream type", invalidUpstreamType(t), true},
|
||||
{"invalid upstream timeout", invalidUpstreamTimeout(t), true},
|
||||
@@ -130,9 +135,15 @@ func invalidListenerIP(t *testing.T) *ctrld.Config {
|
||||
return cfg
|
||||
}
|
||||
|
||||
func emptyListenerIP(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Listener["0"].IP = ""
|
||||
return cfg
|
||||
}
|
||||
|
||||
func invalidListenerPort(t *testing.T) *ctrld.Config {
|
||||
cfg := defaultConfig(t)
|
||||
cfg.Listener["0"].Port = 0
|
||||
cfg.Listener["0"].Port = -1
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -165,116 +176,3 @@ func configWithInvalidRcodes(t *testing.T) *ctrld.Config {
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func TestUpstreamConfig_Init(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uc *ctrld.UpstreamConfig
|
||||
expected *ctrld.UpstreamConfig
|
||||
}{
|
||||
{
|
||||
"doh+doh3",
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: "doh",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "doh",
|
||||
Type: "doh",
|
||||
Endpoint: "https://example.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "example.com",
|
||||
Timeout: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"dot+doq",
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com:8853",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com:8853",
|
||||
BootstrapIP: "",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"dot+doq without port",
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "dot",
|
||||
Type: "dot",
|
||||
Endpoint: "freedns.controld.com:853",
|
||||
BootstrapIP: "",
|
||||
Domain: "freedns.controld.com",
|
||||
Timeout: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"legacy",
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "legacy",
|
||||
Type: "legacy",
|
||||
Endpoint: "1.2.3.4:53",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "legacy",
|
||||
Type: "legacy",
|
||||
Endpoint: "1.2.3.4:53",
|
||||
BootstrapIP: "1.2.3.4",
|
||||
Domain: "1.2.3.4",
|
||||
Timeout: 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"legacy without port",
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "legacy",
|
||||
Type: "legacy",
|
||||
Endpoint: "1.2.3.4",
|
||||
BootstrapIP: "",
|
||||
Domain: "",
|
||||
Timeout: 0,
|
||||
},
|
||||
&ctrld.UpstreamConfig{
|
||||
Name: "legacy",
|
||||
Type: "legacy",
|
||||
Endpoint: "1.2.3.4:53",
|
||||
BootstrapIP: "1.2.3.4",
|
||||
Domain: "1.2.3.4",
|
||||
Timeout: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tc.uc.Init()
|
||||
assert.Equal(t, tc.expected, tc.uc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
101
docs/config.md
101
docs/config.md
@@ -14,10 +14,15 @@ The config file allows for advanced configuration of the `ctrld` utility to cove
|
||||
|
||||
|
||||
## Config Location
|
||||
`ctrld` uses [TOML](toml_link) format for its configuration file. Default configuration file is `config.toml` found in following order:
|
||||
`ctrld` uses [TOML](toml_link) format for its configuration file. Default configuration file is `ctrld.toml` found in following order:
|
||||
|
||||
- `$HOME/.ctrld`
|
||||
- Current directory
|
||||
- `/etc/controld` on *nix.
|
||||
- User's home directory on Windows.
|
||||
- Same directory with `ctrld` binary on these routers:
|
||||
- `ddwrt`
|
||||
- `merlin`
|
||||
- `freshtomato`
|
||||
- Current directory.
|
||||
|
||||
The user can choose to override default value using command line `--config` or `-c`:
|
||||
|
||||
@@ -38,6 +43,8 @@ if it's existed.
|
||||
log_path = ""
|
||||
cache_enable = true
|
||||
cache_size = 4096
|
||||
cache_ttl_override = 60
|
||||
cache_serve_stale = true
|
||||
|
||||
[network.0]
|
||||
cidrs = ["0.0.0.0/0"]
|
||||
@@ -53,6 +60,7 @@ if it's existed.
|
||||
name = "Control D - Anti-Malware"
|
||||
timeout = 5000
|
||||
type = "doh"
|
||||
ip_stack = "both"
|
||||
|
||||
[upstream.1]
|
||||
bootstrap_ip = "76.76.2.11"
|
||||
@@ -60,6 +68,7 @@ if it's existed.
|
||||
name = "Control D - No Ads"
|
||||
timeout = 5000
|
||||
type = "doq"
|
||||
ip_stack = "split"
|
||||
|
||||
[upstream.2]
|
||||
bootstrap_ip = "76.76.2.22"
|
||||
@@ -67,6 +76,7 @@ if it's existed.
|
||||
name = "Control D - Private"
|
||||
timeout = 5000
|
||||
type = "dot"
|
||||
ip_stack = "v4"
|
||||
|
||||
[listener.0]
|
||||
ip = "127.0.0.1"
|
||||
@@ -104,8 +114,8 @@ Logging level you wish to enable.
|
||||
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Valid values: `debug`, `info`, `warn`, `error`, `fatal`, `panic`
|
||||
- Default: `info`
|
||||
- Valid values: `debug`, `info`, `warn`, `notice`, `error`, `fatal`, `panic`
|
||||
- Default: `notice`
|
||||
|
||||
|
||||
### log_path
|
||||
@@ -113,12 +123,14 @@ Relative or absolute path of the log file.
|
||||
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Default: ""
|
||||
|
||||
### cache_enable
|
||||
When `cache_enable = true`, all resolved DNS query responses will be cached for duration of the upstream record TTLs.
|
||||
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: false
|
||||
|
||||
### cache_size
|
||||
The number of cached records, must be a positive integer. Tweaking this value with care depends on your available RAM.
|
||||
@@ -128,29 +140,22 @@ An invalid `cache_size` value will disable the cache, regardless of `cache_enabl
|
||||
|
||||
- Type: int
|
||||
- Required: no
|
||||
- Default: 4096
|
||||
|
||||
### cache_ttl_override
|
||||
When `cache_ttl_override` is set to a positive value (in seconds), TTLs are overridden to this value and cached for this long.
|
||||
|
||||
- Type: int
|
||||
- Required: no
|
||||
- Default: 0
|
||||
|
||||
### cache_serve_stale
|
||||
When `cache_serve_stale = true`, in cases of upstream failures (upstreams not reachable), `ctrld` will keep serving
|
||||
stale cached records (regardless of their TTLs) until upstream comes online.
|
||||
|
||||
The above config will look like this at query time.
|
||||
|
||||
```
|
||||
2022-11-14T22:18:53.808 INF Setting bootstrap IP for upstream.0 bootstrap_ip=76.76.2.11
|
||||
2022-11-14T22:18:53.808 INF Starting DNS server on listener.0: 127.0.0.1:53
|
||||
2022-11-14T22:18:56.381 DBG [9fd5d3] 127.0.0.1:53978 -> listener.0: 127.0.0.1:53: received query: verify.controld.com
|
||||
2022-11-14T22:18:56.381 INF [9fd5d3] no policy, no network, no rule -> [upstream.0]
|
||||
2022-11-14T22:18:56.381 DBG [9fd5d3] sending query to upstream.0: Control D - DOH Free
|
||||
2022-11-14T22:18:56.381 DBG [9fd5d3] debug dial context freedns.controld.com:443 - tcp - 76.76.2.0
|
||||
2022-11-14T22:18:56.381 DBG [9fd5d3] sending doh request to: 76.76.2.11:443
|
||||
2022-11-14T22:18:56.420 DBG [9fd5d3] received response of 118 bytes in 39.662597ms
|
||||
```
|
||||
- Type: boolean
|
||||
- Required: no
|
||||
- Default: false
|
||||
|
||||
## Upstream
|
||||
The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.
|
||||
@@ -162,6 +167,7 @@ The `[upstream]` section specifies the DNS upstream servers that `ctrld` will fo
|
||||
name = "Control D - DOH"
|
||||
timeout = 5000
|
||||
type = "doh"
|
||||
ip_stack = "split"
|
||||
|
||||
[upstream.1]
|
||||
bootstrap_ip = ""
|
||||
@@ -169,6 +175,7 @@ The `[upstream]` section specifies the DNS upstream servers that `ctrld` will fo
|
||||
name = "Control D - DOH3"
|
||||
timeout = 5000
|
||||
type = "doh3"
|
||||
ip_stack = "both"
|
||||
|
||||
[upstream.2]
|
||||
bootstrap_ip = ""
|
||||
@@ -176,6 +183,7 @@ The `[upstream]` section specifies the DNS upstream servers that `ctrld` will fo
|
||||
name = "Controld D - DOT"
|
||||
timeout = 5000
|
||||
type = "dot"
|
||||
ip_stack = "v4"
|
||||
|
||||
[upstream.3]
|
||||
bootstrap_ip = ""
|
||||
@@ -183,6 +191,7 @@ The `[upstream]` section specifies the DNS upstream servers that `ctrld` will fo
|
||||
name = "Controld D - DOT"
|
||||
timeout = 5000
|
||||
type = "doq"
|
||||
ip_stack = "v6"
|
||||
|
||||
[upstream.4]
|
||||
bootstrap_ip = ""
|
||||
@@ -190,6 +199,7 @@ The `[upstream]` section specifies the DNS upstream servers that `ctrld` will fo
|
||||
name = "Control D - Ad Blocking"
|
||||
timeout = 5000
|
||||
type = "legacy"
|
||||
ip_stack = "both"
|
||||
```
|
||||
|
||||
### bootstrap_ip
|
||||
@@ -200,6 +210,7 @@ If `bootstrap_ip` is empty, `ctrld` will resolve this itself using its own boots
|
||||
|
||||
- type: ip address string
|
||||
- required: no
|
||||
- Default: ""
|
||||
|
||||
### endpoint
|
||||
IP address, hostname or URL of upstream DNS. Used together with `Type` of the endpoint.
|
||||
@@ -214,6 +225,7 @@ Human-readable name of the upstream.
|
||||
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Default: ""
|
||||
|
||||
### timeout
|
||||
Timeout in milliseconds before request failsover to the next upstream (if defined).
|
||||
@@ -221,15 +233,34 @@ Timeout in milliseconds before request failsover to the next upstream (if define
|
||||
Value `0` means no timeout.
|
||||
|
||||
- Type: number
|
||||
- required: no
|
||||
- Required: no
|
||||
- Default: 0
|
||||
|
||||
### type
|
||||
The protocol that `ctrld` will use to send DNS requests to upstream.
|
||||
|
||||
- Type: string
|
||||
- required: yes
|
||||
- Required: yes
|
||||
- Valid values: `doh`, `doh3`, `dot`, `doq`, `legacy`, `os`
|
||||
|
||||
### ip_stack
|
||||
Specifying what kind of ip stack that `ctrld` will use to connect to upstream.
|
||||
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Valid values:
|
||||
- `both`: using either ipv4 or ipv6.
|
||||
- `v4`: only dial upstream via IPv4, never dial IPv6.
|
||||
- `v6`: only dial upstream via IPv6, never dial IPv4.
|
||||
- `split`:
|
||||
- If `A` record is requested -> dial via ipv4.
|
||||
- If `AAAA` or any other record is requested -> dial ipv6 (if available, otherwise ipv4)
|
||||
|
||||
If `ip_stack` is empty, or undefined:
|
||||
|
||||
- Default value is `both` for non-Control D resolvers.
|
||||
- Default value is `split` for Control D resolvers.
|
||||
|
||||
## Network
|
||||
The `[network]` section defines networks from which DNS queries can originate from. These are used in policies. You can define multiple networks, and each one can have multiple cidrs.
|
||||
|
||||
@@ -248,12 +279,14 @@ Name of the network.
|
||||
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Default: ""
|
||||
|
||||
### cidrs
|
||||
Specifies the network addresses that the `listener` will accept requests from. You will see more details in the listener policy section.
|
||||
|
||||
- Type: array of network CIDR string
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
|
||||
## listener
|
||||
@@ -271,22 +304,25 @@ The `[listener]` section specifies the ip and port of the local DNS server. You
|
||||
```
|
||||
|
||||
### ip
|
||||
IP address that serves the incoming requests.
|
||||
IP address that serves the incoming requests. If `ip` is empty, ctrld will listen on all available addresses.
|
||||
|
||||
- Type: string
|
||||
- Required: yes
|
||||
- Type: ip address string
|
||||
- Required: no
|
||||
- Default: ""
|
||||
|
||||
### port
|
||||
Port number that the listener will listen on for incoming requests.
|
||||
Port number that the listener will listen on for incoming requests. If `port` is `0`, a random available port will be chosen.
|
||||
|
||||
- Type: number
|
||||
- Required: yes
|
||||
- Required: no
|
||||
- Default: 0
|
||||
|
||||
### restricted
|
||||
If set to `true` makes the listener `REFUSE` DNS queries from all source IP addresses that are not explicitly defined in the policy using a `network`.
|
||||
|
||||
- Type: bool
|
||||
- Required: no
|
||||
- Default: false
|
||||
|
||||
### policy
|
||||
Allows `ctrld` to set policy rules to determine which upstreams the requests will be forwarded to.
|
||||
@@ -330,19 +366,30 @@ rules = [
|
||||
|
||||
- Type: string
|
||||
- Required: no
|
||||
- Default: ""
|
||||
|
||||
### networks:
|
||||
`networks` is the list of network rules of the policy.
|
||||
|
||||
- type: array of networks
|
||||
- Type: array of networks
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
### rules:
|
||||
`rules` is the list of domain rules within the policy. Domain can be either FQDN or wildcard domain.
|
||||
|
||||
- type: array of rule
|
||||
- Type: array of rule
|
||||
- Required: no
|
||||
- Default: []
|
||||
|
||||
### failover_rcodes
|
||||
For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`. For example:
|
||||
For non success response, `failover_rcodes` allows the request to be forwarded to next upstream, if the response `RCODE` matches any value defined in `failover_rcodes`.
|
||||
|
||||
- Type: array of string
|
||||
- Required: no
|
||||
- Default: []
|
||||
-
|
||||
For example:
|
||||
|
||||
```toml
|
||||
[listener.0.policy]
|
||||
|
||||
68
doh.go
68
doh.go
@@ -7,25 +7,35 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
dohMacHeader = "x-cd-mac"
|
||||
dohIPHeader = "x-cd-ip"
|
||||
dohHostHeader = "x-cd-host"
|
||||
headerApplicationDNS = "application/dns-message"
|
||||
)
|
||||
|
||||
func newDohResolver(uc *UpstreamConfig) *dohResolver {
|
||||
r := &dohResolver{
|
||||
endpoint: uc.Endpoint,
|
||||
endpoint: uc.u,
|
||||
isDoH3: uc.Type == ResolverTypeDOH3,
|
||||
transport: uc.transport,
|
||||
http3RoundTripper: uc.http3RoundTripper,
|
||||
sendClientInfo: uc.UpstreamSendClientInfo(),
|
||||
uc: uc,
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
type dohResolver struct {
|
||||
endpoint string
|
||||
uc *UpstreamConfig
|
||||
endpoint *url.URL
|
||||
isDoH3 bool
|
||||
transport *http.Transport
|
||||
http3RoundTripper http.RoundTripper
|
||||
sendClientInfo bool
|
||||
}
|
||||
|
||||
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
@@ -33,26 +43,34 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
enc := base64.RawURLEncoding.EncodeToString(data)
|
||||
url := fmt.Sprintf("%s?dns=%s", r.endpoint, enc)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
query := r.endpoint.Query()
|
||||
query.Add("dns", enc)
|
||||
|
||||
endpoint := *r.endpoint
|
||||
endpoint.RawQuery = query.Encode()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/dns-message")
|
||||
req.Header.Set("Accept", "application/dns-message")
|
||||
|
||||
c := http.Client{Transport: r.transport}
|
||||
addHeader(ctx, req, r.sendClientInfo)
|
||||
dnsTyp := uint16(0)
|
||||
if len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
c := http.Client{Transport: r.uc.dohTransport(dnsTyp)}
|
||||
if r.isDoH3 {
|
||||
if r.http3RoundTripper == nil {
|
||||
transport := r.uc.doh3Transport(dnsTyp)
|
||||
if transport == nil {
|
||||
return nil, errors.New("DoH3 is not supported")
|
||||
}
|
||||
c.Transport = r.http3RoundTripper
|
||||
c.Transport = transport
|
||||
}
|
||||
resp, err := c.Do(req)
|
||||
if err != nil {
|
||||
if r.isDoH3 {
|
||||
if closer, ok := r.http3RoundTripper.(io.Closer); ok {
|
||||
if closer, ok := c.Transport.(io.Closer); ok {
|
||||
closer.Close()
|
||||
}
|
||||
}
|
||||
@@ -70,5 +88,27 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
|
||||
}
|
||||
|
||||
answer := new(dns.Msg)
|
||||
return answer, answer.Unpack(buf)
|
||||
if err := answer.Unpack(buf); err != nil {
|
||||
return nil, fmt.Errorf("answer.Unpack: %w", err)
|
||||
}
|
||||
return answer, nil
|
||||
}
|
||||
|
||||
func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
|
||||
req.Header.Set("Content-Type", headerApplicationDNS)
|
||||
req.Header.Set("Accept", headerApplicationDNS)
|
||||
if sendClientInfo {
|
||||
if ci, ok := ctx.Value(ClientInfoCtxKey{}).(*ClientInfo); ok && ci != nil {
|
||||
if ci.Mac != "" {
|
||||
req.Header.Set(dohMacHeader, ci.Mac)
|
||||
}
|
||||
if ci.IP != "" {
|
||||
req.Header.Set(dohIPHeader, ci.IP)
|
||||
}
|
||||
if ci.Hostname != "" {
|
||||
req.Header.Set(dohHostHeader, ci.Hostname)
|
||||
}
|
||||
}
|
||||
}
|
||||
Log(ctx, ProxyLog.Debug().Interface("header", req.Header), "sending request header")
|
||||
}
|
||||
|
||||
14
doq.go
14
doq.go
@@ -20,11 +20,17 @@ type doqResolver struct {
|
||||
func (r *doqResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
endpoint := r.uc.Endpoint
|
||||
tlsConfig := &tls.Config{NextProtos: []string{"doq"}}
|
||||
if r.uc.BootstrapIP != "" {
|
||||
tlsConfig.ServerName = r.uc.Domain
|
||||
_, port, _ := net.SplitHostPort(endpoint)
|
||||
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
|
||||
ip := r.uc.BootstrapIP
|
||||
if ip == "" {
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
ip = r.uc.bootstrapIPForDNSType(dnsTyp)
|
||||
}
|
||||
tlsConfig.ServerName = r.uc.Domain
|
||||
_, port, _ := net.SplitHostPort(endpoint)
|
||||
endpoint = net.JoinHostPort(ip, port)
|
||||
return resolve(ctx, msg, endpoint, tlsConfig)
|
||||
}
|
||||
|
||||
|
||||
16
dot.go
16
dot.go
@@ -14,18 +14,26 @@ type dotResolver struct {
|
||||
|
||||
func (r *dotResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
// The dialer is used to prevent bootstrapping cycle.
|
||||
// If r.endpoing is set to dns.controld.dev, we need to resolve
|
||||
// If r.endpoint is set to dns.controld.dev, we need to resolve
|
||||
// dns.controld.dev first. By using a dialer with custom resolver,
|
||||
// we ensure that we can always resolve the bootstrap domain
|
||||
// regardless of the machine DNS status.
|
||||
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
|
||||
tcpNet, _ := r.uc.netForDNSType(dnsTyp)
|
||||
dnsClient := &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
Dialer: dialer,
|
||||
Net: tcpNet,
|
||||
Dialer: dialer,
|
||||
TLSConfig: &tls.Config{RootCAs: r.uc.certPool},
|
||||
}
|
||||
endpoint := r.uc.Endpoint
|
||||
if r.uc.BootstrapIP != "" {
|
||||
dnsClient.TLSConfig = &tls.Config{ServerName: r.uc.Domain}
|
||||
dnsClient.TLSConfig.ServerName = r.uc.Domain
|
||||
dnsClient.Net = "tcp-tls"
|
||||
_, port, _ := net.SplitHostPort(endpoint)
|
||||
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
|
||||
}
|
||||
|
||||
43
errors.go
43
errors.go
@@ -1,43 +0,0 @@
|
||||
package ctrld
|
||||
|
||||
// TODO(cuonglm): use stdlib once we bump minimum version to 1.20
|
||||
|
||||
func joinErrors(errs ...error) error {
|
||||
n := 0
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
n++
|
||||
}
|
||||
}
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
e := &joinError{
|
||||
errs: make([]error, 0, n),
|
||||
}
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
e.errs = append(e.errs, err)
|
||||
}
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
type joinError struct {
|
||||
errs []error
|
||||
}
|
||||
|
||||
func (e *joinError) Error() string {
|
||||
var b []byte
|
||||
for i, err := range e.errs {
|
||||
if i > 0 {
|
||||
b = append(b, '\n')
|
||||
}
|
||||
b = append(b, err.Error()...)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func (e *joinError) Unwrap() []error {
|
||||
return e.errs
|
||||
}
|
||||
31
go.mod
31
go.mod
@@ -1,16 +1,18 @@
|
||||
module github.com/Control-D-Inc/ctrld
|
||||
|
||||
go 1.19
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534
|
||||
github.com/coreos/go-systemd/v22 v22.5.0
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19
|
||||
github.com/frankban/quicktest v1.14.3
|
||||
github.com/fsnotify/fsnotify v1.6.0
|
||||
github.com/go-playground/validator/v10 v10.11.1
|
||||
github.com/godbus/dbus/v5 v5.0.6
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.1
|
||||
github.com/illarion/gonotify v1.0.1
|
||||
github.com/insomniacslk/dhcp v0.0.0-20211209223715-7d93572ebe8e
|
||||
github.com/insomniacslk/dhcp v0.0.0-20221215072855-de60144f33f8
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86
|
||||
github.com/kardianos/service v1.2.1
|
||||
github.com/miekg/dns v1.1.50
|
||||
github.com/pelletier/go-toml/v2 v2.0.6
|
||||
@@ -19,10 +21,12 @@ require (
|
||||
github.com/spf13/cobra v1.4.0
|
||||
github.com/spf13/viper v1.14.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54
|
||||
golang.org/x/net v0.7.0
|
||||
golang.org/x/sync v0.1.0
|
||||
golang.org/x/sys v0.5.0
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
tailscale.com v1.34.1
|
||||
tailscale.com v1.38.3
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -36,7 +40,6 @@ require (
|
||||
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.0.0 // indirect
|
||||
github.com/josharian/native v1.0.0 // indirect
|
||||
github.com/jsimonetti/rtnetlink v1.1.2-0.20220408201609-d380b505068b // indirect
|
||||
github.com/kr/pretty v0.3.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
@@ -45,9 +48,9 @@ require (
|
||||
github.com/mattn/go-colorable v0.1.12 // indirect
|
||||
github.com/mattn/go-isatty v0.0.14 // indirect
|
||||
github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7 // indirect
|
||||
github.com/mdlayher/netlink v1.6.0 // indirect
|
||||
github.com/mdlayher/netlink v1.7.1 // indirect
|
||||
github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 // indirect
|
||||
github.com/mdlayher/socket v0.2.3 // indirect
|
||||
github.com/mdlayher/socket v0.4.0 // indirect
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.2.0 // indirect
|
||||
github.com/pelletier/go-toml v1.9.5 // indirect
|
||||
@@ -62,15 +65,19 @@ require (
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/subosito/gotenv v1.4.1 // indirect
|
||||
github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4 // indirect
|
||||
github.com/u-root/uio v0.0.0-20221213070652-c3537552635f // indirect
|
||||
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
|
||||
go4.org/mem v0.0.0-20210711025021-927187094b94 // indirect
|
||||
golang.org/x/crypto v0.4.0 // indirect
|
||||
golang.org/x/crypto v0.6.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect
|
||||
golang.org/x/mod v0.6.0 // indirect
|
||||
golang.org/x/net v0.7.0 // indirect
|
||||
golang.org/x/mod v0.7.0 // indirect
|
||||
golang.org/x/text v0.7.0 // indirect
|
||||
golang.org/x/tools v0.2.0 // indirect
|
||||
golang.org/x/tools v0.4.1-0.20221208213631-3f74d914ae6d // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
replace github.com/mr-karan/doggo => github.com/Windscribe/doggo v0.0.0-20220919152748-2c118fc391f8
|
||||
|
||||
replace github.com/rs/zerolog => github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be
|
||||
|
||||
58
go.sum
58
go.sum
@@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
|
||||
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be h1:qBKVRi7Mom5heOkyZ+NCIu9HZBiNCsRqrRe5t9pooik=
|
||||
github.com/Windscribe/zerolog v0.0.0-20230503170159-e6aa153233be/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w=
|
||||
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74 h1:Kk6a4nehpJ3UuJRqlA3JxYxBZEqCeOmATOvrbT4p9RA=
|
||||
github.com/alexbrainman/sspi v0.0.0-20210105120005-909beea2cc74/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
|
||||
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
|
||||
@@ -50,10 +52,12 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk
|
||||
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
|
||||
github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534 h1:rtAn27wIbmOGUs7RIbVgPEjb31ehTVniDwPGXyMxm5U=
|
||||
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19 h1:7P/f19Mr0oa3ug8BYt4JuRe/Zq3dF4Mrr4m8+Kw+Hcs=
|
||||
github.com/cuonglm/osinfo v0.0.0-20230329055532-c513f836da19/go.mod h1:G45410zMgmnSjLVKCq4f6GpbYAzoP2plX9rPwgx6C24=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -163,10 +167,12 @@ github.com/illarion/gonotify v1.0.1 h1:F1d+0Fgbq/sDWjj/r66ekjDG+IDeecQKUFH4wNwso
|
||||
github.com/illarion/gonotify v1.0.1/go.mod h1:zt5pmDofZpU1f8aqlK0+95eQhoEAn/d4G4B/FjVW4jE=
|
||||
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
|
||||
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20211209223715-7d93572ebe8e h1:IQpunlq7T+NiJJMO7ODYV2YWBiv/KnObR3gofX0mWOo=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20211209223715-7d93572ebe8e/go.mod h1:h+MxyHxRg9NH3terB1nfRIUaQEcI0XOVkdR9LNBlp8E=
|
||||
github.com/josharian/native v1.0.0 h1:Ts/E8zCSEsG17dUqv7joXJFybuMLjQfWE04tsBODTxk=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20221215072855-de60144f33f8 h1:Z72DOke2yOK0Ms4Z2LK1E1OrRJXOxSj5DllTz2FYTRg=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20221215072855-de60144f33f8/go.mod h1:m5WMe03WCvWcXjRnhvaAbAAXdCnu20J5P+mmH44ZzpE=
|
||||
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk=
|
||||
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok=
|
||||
@@ -202,14 +208,15 @@ github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE
|
||||
github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M=
|
||||
github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY=
|
||||
github.com/mdlayher/netlink v1.1.1/go.mod h1:WTYpFb/WTvlRJAyKhZL5/uy69TDDpHHu2VZmb2XgV7o=
|
||||
github.com/mdlayher/netlink v1.6.0 h1:rOHX5yl7qnlpiVkFWoqccueppMtXzeziFjWAjLg6sz0=
|
||||
github.com/mdlayher/netlink v1.6.0/go.mod h1:0o3PlBmGst1xve7wQ7j/hwpNaFaH4qCRyWCdcZk8/vA=
|
||||
github.com/mdlayher/netlink v1.7.1 h1:FdUaT/e33HjEXagwELR8R3/KL1Fq5x3G5jgHLp/BTmg=
|
||||
github.com/mdlayher/netlink v1.7.1/go.mod h1:nKO5CSjE/DJjVhk/TNp6vCE1ktVxEA8VEh8drhZzxsQ=
|
||||
github.com/mdlayher/raw v0.0.0-20190606142536-fef19f00fc18/go.mod h1:7EpbotpCmVZcu+KCX4g9WaRNuu11uyhiW7+Le1dKawg=
|
||||
github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 h1:aFkJ6lx4FPip+S+Uw4aTegFMct9shDvP+79PsSxpm3w=
|
||||
github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065/go.mod h1:7EpbotpCmVZcu+KCX4g9WaRNuu11uyhiW7+Le1dKawg=
|
||||
github.com/mdlayher/socket v0.1.1/go.mod h1:mYV5YIZAfHh4dzDVzI8x8tWLWCliuX8Mon5Awbj+qDs=
|
||||
github.com/mdlayher/socket v0.2.3 h1:XZA2X2TjdOwNoNPVPclRCURoX/hokBY8nkTmRZFEheM=
|
||||
github.com/mdlayher/socket v0.2.3/go.mod h1:bz12/FozYNH/VbvC3q7TRIK/Y6dH1kCKsXaUeXi/FmY=
|
||||
github.com/mdlayher/socket v0.4.0 h1:280wsy40IC9M9q1uPGcLBwXpcTQDtoGwVt+BNoITxIw=
|
||||
github.com/mdlayher/socket v0.4.0/go.mod h1:xxFqz5GRCUN3UEOm9CZqEJsAbe1C8OwSK46NlmWuVoc=
|
||||
github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA=
|
||||
github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
@@ -242,9 +249,7 @@ github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTE
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/rogpeppe/go-internal v1.8.1-0.20211023094830-115ce09fd6b4 h1:Ha8xCaq6ln1a+R91Km45Oq6lPXj2Mla6CRJYcuV2h1w=
|
||||
github.com/rogpeppe/go-internal v1.8.1-0.20211023094830-115ce09fd6b4/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY=
|
||||
github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
|
||||
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
|
||||
@@ -275,9 +280,13 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/subosito/gotenv v1.4.1 h1:jyEFiXpy21Wm81FBN71l9VoMMV8H8jG+qIK3GCpY6Qs=
|
||||
github.com/subosito/gotenv v1.4.1/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0=
|
||||
github.com/u-root/uio v0.0.0-20210528114334-82958018845c/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA=
|
||||
github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4 h1:hl6sK6aFgTLISijk6xIzeqnPzQcsLqqvL6vEfTPinME=
|
||||
github.com/u-root/uio v0.0.0-20220204230159-dac05f7d2cb4/go.mod h1:LpEX5FO/cB+WF4TYGY1V5qktpaZLkKkSegbr0V4eYXA=
|
||||
github.com/u-root/uio v0.0.0-20221213070652-c3537552635f h1:dpx1PHxYqAnXzbryJrWP1NQLzEjwcVgFLhkknuFQ7ww=
|
||||
github.com/u-root/uio v0.0.0-20221213070652-c3537552635f/go.mod h1:IogEAUBXDEwX7oR/BMmCctShYs80ql4hF0ySdzGxf7E=
|
||||
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 h1:8mhqcHPqTMhSPoslhGYihEgSfc77+7La1P6kiB6+9So=
|
||||
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
|
||||
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
|
||||
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
@@ -299,8 +308,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=
|
||||
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
|
||||
golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc=
|
||||
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@@ -337,8 +346,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I=
|
||||
golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
|
||||
golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA=
|
||||
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -425,6 +434,7 @@ golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -433,6 +443,7 @@ golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20200511232937-7e40ca221e25/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200515095857-1151b9dac4a9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -448,7 +459,6 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -456,7 +466,9 @@ golang.org/x/sys v0.0.0-20210906170528-6f6e22806c34/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
@@ -524,8 +536,8 @@ golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4f
|
||||
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
|
||||
golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.6-0.20210726203631-07bc1bf47fb2/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.2.0 h1:G6AHpWxTMGY1KyEYoAQ5WTtIekUUvDNjan3ugu60JvE=
|
||||
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
|
||||
golang.org/x/tools v0.4.1-0.20221208213631-3f74d914ae6d h1:9ZNWAi4CYhNv60mXGgAncgq7SGc5qa7C8VZV8Tg7Ggs=
|
||||
golang.org/x/tools v0.4.1-0.20221208213631-3f74d914ae6d/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -645,5 +657,5 @@ honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9
|
||||
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
|
||||
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
|
||||
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
|
||||
tailscale.com v1.34.1 h1:tqm9Ww4ltyYp3IPe7vCGch6tT6j5G/WXPQ6BrVZ6pdI=
|
||||
tailscale.com v1.34.1/go.mod h1:ZsBP7rjzzB2rp+UCOumr9DAe0EQ6OPivwSXcz/BrekQ=
|
||||
tailscale.com v1.38.3 h1:2aX3+u0Re8QcN6nq7zf9Aa4ZCR2Nf6Imv3isqdQrb58=
|
||||
tailscale.com v1.38.3/go.mod h1:UWLQxcd8dz+lds2I+HpfXSruHrvXM1j4zd4zdx86t7w=
|
||||
|
||||
3372
internal/certs/cacert.pem
Normal file
3372
internal/certs/cacert.pem
Normal file
File diff suppressed because it is too large
Load Diff
22
internal/certs/root_ca.go
Normal file
22
internal/certs/root_ca.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package certs
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
_ "embed"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed cacert.pem
|
||||
caRoots []byte
|
||||
caCertPoolOnce sync.Once
|
||||
caCertPool *x509.CertPool
|
||||
)
|
||||
|
||||
func CACertPool() *x509.CertPool {
|
||||
caCertPoolOnce.Do(func() {
|
||||
caCertPool = x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caRoots)
|
||||
})
|
||||
return caCertPool
|
||||
}
|
||||
27
internal/certs/root_ca_test.go
Normal file
27
internal/certs/root_ca_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package certs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCACertPool(t *testing.T) {
|
||||
c := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: CACertPool(),
|
||||
},
|
||||
},
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
resp, err := c.Get("https://freedns.controld.com/p1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if !resp.TLS.HandshakeComplete {
|
||||
t.Error("TLS handshake is not complete")
|
||||
}
|
||||
}
|
||||
@@ -3,23 +3,33 @@ package controld
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/certs"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
const (
|
||||
resolverDataURL = "https://api.controld.com/utility"
|
||||
InvalidConfigCode = 40401
|
||||
apiDomainCom = "api.controld.com"
|
||||
apiDomainDev = "api.controld.dev"
|
||||
resolverDataURLCom = "https://api.controld.com/utility"
|
||||
resolverDataURLDev = "https://api.controld.dev/utility"
|
||||
InvalidConfigCode = 40401
|
||||
)
|
||||
|
||||
// ResolverConfig represents Control D resolver data.
|
||||
type ResolverConfig struct {
|
||||
DOH string `json:"doh"`
|
||||
DOH string `json:"doh"`
|
||||
Ctrld struct {
|
||||
CustomConfig string `json:"custom_config"`
|
||||
} `json:"ctrld"`
|
||||
Exclude []string `json:"exclude"`
|
||||
}
|
||||
|
||||
@@ -46,25 +56,44 @@ type utilityRequest struct {
|
||||
}
|
||||
|
||||
// FetchResolverConfig fetch Control D config for given uid.
|
||||
func FetchResolverConfig(uid string) (*ResolverConfig, error) {
|
||||
func FetchResolverConfig(uid, version string, cdDev bool) (*ResolverConfig, error) {
|
||||
body, _ := json.Marshal(utilityRequest{UID: uid})
|
||||
req, err := http.NewRequest("POST", resolverDataURL, bytes.NewReader(body))
|
||||
apiUrl := resolverDataURLCom
|
||||
if cdDev {
|
||||
apiUrl = resolverDataURLDev
|
||||
}
|
||||
req, err := http.NewRequest("POST", apiUrl, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("http.NewRequest: %w", err)
|
||||
}
|
||||
q := req.URL.Query()
|
||||
q.Set("platform", "ctrld")
|
||||
q.Set("version", version)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
// We experiment hanging in TLS handshake when connecting to ControlD API
|
||||
// with ipv6. So prefer ipv4 if available.
|
||||
proto := "tcp6"
|
||||
if ctrldnet.SupportsIPv4() {
|
||||
proto = "tcp4"
|
||||
apiDomain := apiDomainCom
|
||||
if cdDev {
|
||||
apiDomain = apiDomainDev
|
||||
}
|
||||
return ctrldnet.Dialer.DialContext(ctx, proto, addr)
|
||||
ips := ctrld.LookupIP(apiDomain)
|
||||
if len(ips) == 0 {
|
||||
ctrld.ProxyLog.Warn().Msgf("No IPs found for %s, connecting to %s", apiDomain, addr)
|
||||
return ctrldnet.Dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
ctrld.ProxyLog.Debug().Msgf("API IPs: %v", ips)
|
||||
_, port, _ := net.SplitHostPort(addr)
|
||||
addrs := make([]string, len(ips))
|
||||
for i := range ips {
|
||||
addrs[i] = net.JoinHostPort(ips[i], port)
|
||||
}
|
||||
d := &ctrldnet.ParallelDialer{}
|
||||
return d.DialContext(ctx, network, addrs)
|
||||
}
|
||||
|
||||
if router.Name() == router.DDWrt {
|
||||
transport.TLSClientConfig = &tls.Config{RootCAs: certs.CACertPool()}
|
||||
}
|
||||
client := http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
|
||||
@@ -9,22 +9,22 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const utilityURL = "https://api.controld.com/utility"
|
||||
|
||||
func TestFetchResolverConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uid string
|
||||
dev bool
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid", "p2", false},
|
||||
{"invalid uid", "abcd1234", true},
|
||||
{"valid com", "p2", false, false},
|
||||
{"valid dev", "p2", true, false},
|
||||
{"invalid uid", "abcd1234", false, true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := FetchResolverConfig(tc.uid)
|
||||
got, err := FetchResolverConfig(tc.uid, "dev-test", tc.dev)
|
||||
require.False(t, (err != nil) != tc.wantErr, err)
|
||||
if !tc.wantErr {
|
||||
assert.NotEmpty(t, got.DOH)
|
||||
|
||||
@@ -14,8 +14,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
"github.com/josharian/native"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/endian"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -131,7 +131,7 @@ func (m *nmManager) trySet(ctx context.Context, config OSConfig) error {
|
||||
for _, ip := range config.Nameservers {
|
||||
b := ip.As16()
|
||||
if ip.Is4() {
|
||||
dnsv4 = append(dnsv4, endian.Native.Uint32(b[12:]))
|
||||
dnsv4 = append(dnsv4, native.Endian.Uint32(b[12:]))
|
||||
} else {
|
||||
dnsv6 = append(dnsv6, b[:])
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/strs"
|
||||
)
|
||||
|
||||
// Path is the canonical location of resolv.conf.
|
||||
@@ -63,7 +62,7 @@ func Parse(r io.Reader) (*Config, error) {
|
||||
line, _, _ = strings.Cut(line, "#") // remove any comments
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if s, ok := strs.CutPrefix(line, "nameserver"); ok {
|
||||
if s, ok := strings.CutPrefix(line, "nameserver"); ok {
|
||||
nameserver := strings.TrimSpace(s)
|
||||
if len(nameserver) == len(s) {
|
||||
return nil, fmt.Errorf("missing space after \"nameserver\" in %q", line)
|
||||
@@ -76,7 +75,7 @@ func Parse(r io.Reader) (*Config, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := strs.CutPrefix(line, "search"); ok {
|
||||
if s, ok := strings.CutPrefix(line, "search"); ok {
|
||||
domains := strings.TrimSpace(s)
|
||||
if len(domains) == len(s) {
|
||||
// No leading space?!
|
||||
|
||||
@@ -2,6 +2,7 @@ package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -12,7 +13,6 @@ import (
|
||||
|
||||
const (
|
||||
controldIPv6Test = "ipv6.controld.io"
|
||||
controldIPv4Test = "ipv4.controld.io"
|
||||
bootstrapDNS = "76.76.2.0:53"
|
||||
)
|
||||
|
||||
@@ -37,8 +37,6 @@ var probeStackDialer = &net.Dialer{
|
||||
|
||||
var (
|
||||
stackOnce atomic.Pointer[sync.Once]
|
||||
ipv4Enabled bool
|
||||
ipv6Enabled bool
|
||||
canListenIPv6Local bool
|
||||
hasNetworkUp bool
|
||||
)
|
||||
@@ -47,13 +45,8 @@ func init() {
|
||||
stackOnce.Store(new(sync.Once))
|
||||
}
|
||||
|
||||
func supportIPv4() bool {
|
||||
_, err := probeStackDialer.Dial("tcp4", net.JoinHostPort(controldIPv4Test, "80"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func supportIPv6(ctx context.Context) bool {
|
||||
_, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "80"))
|
||||
_, err := probeStackDialer.DialContext(ctx, "tcp6", net.JoinHostPort(controldIPv6Test, "443"))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -66,7 +59,7 @@ func supportListenIPv6Local() bool {
|
||||
}
|
||||
|
||||
func probeStack() {
|
||||
b := backoff.NewBackoff("probeStack", func(format string, args ...any) {}, time.Minute)
|
||||
b := backoff.NewBackoff("probeStack", func(format string, args ...any) {}, 5*time.Second)
|
||||
for {
|
||||
if _, err := probeStackDialer.Dial("udp", bootstrapDNS); err == nil {
|
||||
hasNetworkUp = true
|
||||
@@ -75,8 +68,6 @@ func probeStack() {
|
||||
b.BackOff(context.Background(), err)
|
||||
}
|
||||
}
|
||||
ipv4Enabled = supportIPv4()
|
||||
ipv6Enabled = supportIPv6(context.Background())
|
||||
canListenIPv6Local = supportListenIPv6Local()
|
||||
}
|
||||
|
||||
@@ -85,16 +76,6 @@ func Up() bool {
|
||||
return hasNetworkUp
|
||||
}
|
||||
|
||||
func SupportsIPv4() bool {
|
||||
stackOnce.Load().Do(probeStack)
|
||||
return ipv4Enabled
|
||||
}
|
||||
|
||||
func SupportsIPv6() bool {
|
||||
stackOnce.Load().Do(probeStack)
|
||||
return ipv6Enabled
|
||||
}
|
||||
|
||||
func SupportsIPv6ListenLocal() bool {
|
||||
stackOnce.Load().Do(probeStack)
|
||||
return canListenIPv6Local
|
||||
@@ -112,3 +93,47 @@ func IsIPv6(ip string) bool {
|
||||
parsedIP := net.ParseIP(ip)
|
||||
return parsedIP != nil && parsedIP.To4() == nil && parsedIP.To16() != nil
|
||||
}
|
||||
|
||||
type parallelDialerResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
type ParallelDialer struct {
|
||||
net.Dialer
|
||||
}
|
||||
|
||||
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string) (net.Conn, error) {
|
||||
if len(addrs) == 0 {
|
||||
return nil, errors.New("empty addresses")
|
||||
}
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
ch := make(chan *parallelDialerResult, len(addrs))
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(addrs))
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
for _, addr := range addrs {
|
||||
go func(addr string) {
|
||||
defer wg.Done()
|
||||
conn, err := d.Dialer.DialContext(ctx, network, addr)
|
||||
ch <- ¶llelDialerResult{conn: conn, err: err}
|
||||
}(addr)
|
||||
}
|
||||
|
||||
errs := make([]error, 0, len(addrs))
|
||||
for res := range ch {
|
||||
if res.err == nil {
|
||||
cancel()
|
||||
return res.conn, res.err
|
||||
}
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
193
internal/router/client_info.go
Normal file
193
internal/router/client_info.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"tailscale.com/util/lineread"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// readClientInfoFunc represents the function for reading client info.
|
||||
type readClientInfoFunc func(name string) error
|
||||
|
||||
// clientInfoFiles specifies client info files and how to read them on supported platforms.
|
||||
var clientInfoFiles = map[string]readClientInfoFunc{
|
||||
"/tmp/dnsmasq.leases": dnsmasqReadClientInfoFile, // ddwrt
|
||||
"/tmp/dhcp.leases": dnsmasqReadClientInfoFile, // openwrt
|
||||
"/var/lib/misc/dnsmasq.leases": dnsmasqReadClientInfoFile, // merlin
|
||||
"/mnt/data/udapi-config/dnsmasq.lease": dnsmasqReadClientInfoFile, // UDM Pro
|
||||
"/data/udapi-config/dnsmasq.lease": dnsmasqReadClientInfoFile, // UDR
|
||||
"/etc/dhcpd/dhcpd-leases.log": dnsmasqReadClientInfoFile, // Synology
|
||||
"/tmp/var/lib/misc/dnsmasq.leases": dnsmasqReadClientInfoFile, // Tomato
|
||||
"/run/dnsmasq-dhcp.leases": dnsmasqReadClientInfoFile, // EdgeOS
|
||||
"/run/dhcpd.leases": iscDHCPReadClientInfoFile, // EdgeOS
|
||||
"/var/dhcpd/var/db/dhcpd.leases": iscDHCPReadClientInfoFile, // Pfsense
|
||||
}
|
||||
|
||||
// watchClientInfoTable watches changes happens in dnsmasq/dhcpd
|
||||
// lease files, perform updating to mac table if necessary.
|
||||
func (r *router) watchClientInfoTable() {
|
||||
if r.watcher == nil {
|
||||
return
|
||||
}
|
||||
timer := time.NewTicker(time.Minute * 5)
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
for _, name := range r.watcher.WatchList() {
|
||||
_ = clientInfoFiles[name](name)
|
||||
}
|
||||
case event, ok := <-r.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Write) {
|
||||
readFunc := clientInfoFiles[event.Name]
|
||||
if readFunc == nil {
|
||||
log.Println("unknown file format:", event.Name)
|
||||
continue
|
||||
}
|
||||
if err := readFunc(event.Name); err != nil && !os.IsNotExist(err) {
|
||||
log.Println("could not read client info file:", err)
|
||||
}
|
||||
}
|
||||
case err, ok := <-r.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Println("error:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop performs tasks need to be done before the router stopped.
|
||||
func Stop() error {
|
||||
if Name() == "" {
|
||||
return nil
|
||||
}
|
||||
r := routerPlatform.Load()
|
||||
if r.watcher != nil {
|
||||
if err := r.watcher.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetClientInfoByMac returns ClientInfo for the client associated with the given mac.
|
||||
func GetClientInfoByMac(mac string) *ctrld.ClientInfo {
|
||||
if mac == "" {
|
||||
return nil
|
||||
}
|
||||
_ = Name()
|
||||
r := routerPlatform.Load()
|
||||
val, ok := r.mac.Load(mac)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return val.(*ctrld.ClientInfo)
|
||||
}
|
||||
|
||||
// dnsmasqReadClientInfoFile populates mac table with client info reading from dnsmasq lease file.
|
||||
func dnsmasqReadClientInfoFile(name string) error {
|
||||
f, err := os.Open(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
return dnsmasqReadClientInfoReader(f)
|
||||
|
||||
}
|
||||
|
||||
// dnsmasqReadClientInfoReader likes dnsmasqReadClientInfoFile, but reading from an io.Reader instead of file.
|
||||
func dnsmasqReadClientInfoReader(reader io.Reader) error {
|
||||
r := routerPlatform.Load()
|
||||
return lineread.Reader(reader, func(line []byte) error {
|
||||
fields := bytes.Fields(line)
|
||||
if len(fields) < 4 {
|
||||
return nil
|
||||
}
|
||||
mac := string(fields[1])
|
||||
if _, err := net.ParseMAC(mac); err != nil {
|
||||
// The second field is not a mac, skip.
|
||||
return nil
|
||||
}
|
||||
ip := normalizeIP(string(fields[2]))
|
||||
if net.ParseIP(ip) == nil {
|
||||
log.Printf("invalid ip address entry: %q", ip)
|
||||
ip = ""
|
||||
}
|
||||
hostname := string(fields[3])
|
||||
r.mac.Store(mac, &ctrld.ClientInfo{Mac: mac, IP: ip, Hostname: hostname})
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// iscDHCPReadClientInfoFile populates mac table with client info reading from isc-dhcpd lease file.
|
||||
func iscDHCPReadClientInfoFile(name string) error {
|
||||
f, err := os.Open(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
return iscDHCPReadClientInfoReader(f)
|
||||
}
|
||||
|
||||
// iscDHCPReadClientInfoReader likes iscDHCPReadClientInfoFile, but reading from an io.Reader instead of file.
|
||||
func iscDHCPReadClientInfoReader(reader io.Reader) error {
|
||||
r := routerPlatform.Load()
|
||||
s := bufio.NewScanner(reader)
|
||||
var ip, mac, hostname string
|
||||
for s.Scan() {
|
||||
line := s.Text()
|
||||
if strings.HasPrefix(line, "}") {
|
||||
if mac != "" {
|
||||
r.mac.Store(mac, &ctrld.ClientInfo{Mac: mac, IP: ip, Hostname: hostname})
|
||||
ip, mac, hostname = "", "", ""
|
||||
}
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
switch fields[0] {
|
||||
case "lease":
|
||||
ip = normalizeIP(strings.ToLower(fields[1]))
|
||||
if net.ParseIP(ip) == nil {
|
||||
log.Printf("invalid ip address entry: %q", ip)
|
||||
ip = ""
|
||||
}
|
||||
case "hardware":
|
||||
if len(fields) >= 3 {
|
||||
mac = strings.ToLower(strings.TrimRight(fields[2], ";"))
|
||||
if _, err := net.ParseMAC(mac); err != nil {
|
||||
// Invalid mac, skip.
|
||||
mac = ""
|
||||
}
|
||||
}
|
||||
case "client-hostname":
|
||||
hostname = strings.Trim(fields[1], `";`)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeIP normalizes the ip parsed from dnsmasq/dhcpd lease file.
|
||||
func normalizeIP(in string) string {
|
||||
// dnsmasq may put ip with interface index in lease file, strip it here.
|
||||
ip, _, found := strings.Cut(in, "%")
|
||||
if found {
|
||||
return ip
|
||||
}
|
||||
return in
|
||||
}
|
||||
107
internal/router/client_info_test.go
Normal file
107
internal/router/client_info_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_normalizeIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"v4", "127.0.0.1", "127.0.0.1"},
|
||||
{"v4 with index", "127.0.0.1%lo", "127.0.0.1"},
|
||||
{"v6", "fe80::1", "fe80::1"},
|
||||
{"v6 with index", "fe80::1%22002", "fe80::1"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := normalizeIP(tc.in); got != tc.want {
|
||||
t.Errorf("normalizeIP() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_readClientInfoReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
readFunc func(r io.Reader) error
|
||||
mac string
|
||||
}{
|
||||
{
|
||||
"good dnsmasq",
|
||||
`1683329857 e6:20:59:b8:c1:6d 192.168.1.186 * 01:e6:20:59:b8:c1:6d
|
||||
`,
|
||||
dnsmasqReadClientInfoReader,
|
||||
"e6:20:59:b8:c1:6d",
|
||||
},
|
||||
{
|
||||
"bad dnsmasq seen on UDMdream machine",
|
||||
`1683329857 e6:20:59:b8:c1:6e 192.168.1.111 * 01:e6:20:59:b8:c1:6e
|
||||
duid 00:01:00:01:2b:e4:2e:2c:52:52:14:26:dc:1c
|
||||
1683322985 117442354 2600:4040:b0e6:b700::111 ASDASD 00:01:00:01:2a:d0:b9:81:00:07:32:4c:1c:07
|
||||
`,
|
||||
dnsmasqReadClientInfoReader,
|
||||
"e6:20:59:b8:c1:6e",
|
||||
},
|
||||
{
|
||||
"isc-dhcpd good",
|
||||
`lease 192.168.1.1 {
|
||||
hardware ethernet 00:00:00:00:00:01;
|
||||
client-hostname "host-1";
|
||||
}
|
||||
`,
|
||||
iscDHCPReadClientInfoReader,
|
||||
"00:00:00:00:00:01",
|
||||
},
|
||||
{
|
||||
"isc-dhcpd bad mac",
|
||||
`lease 192.168.1.1 {
|
||||
hardware ethernet invalid-mac;
|
||||
client-hostname "host-1";
|
||||
}
|
||||
|
||||
lease 192.168.1.2 {
|
||||
hardware ethernet 00:00:00:00:00:02;
|
||||
client-hostname "host-2";
|
||||
}
|
||||
`,
|
||||
iscDHCPReadClientInfoReader,
|
||||
"00:00:00:00:00:02",
|
||||
},
|
||||
{
|
||||
"",
|
||||
`1685794060 00:00:00:00:00:04 192.168.0.209 cuonglm-ThinkPad-X1-Carbon-Gen-9 00:00:00:00:00:04 9`,
|
||||
dnsmasqReadClientInfoReader,
|
||||
"00:00:00:00:00:04",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := routerPlatform.Load()
|
||||
r.mac.Delete(tc.mac)
|
||||
if err := tc.readFunc(strings.NewReader(tc.in)); err != nil {
|
||||
t.Errorf("readClientInfoReader() error = %v", err)
|
||||
}
|
||||
info, existed := r.mac.Load(tc.mac)
|
||||
if !existed {
|
||||
t.Error("client info missing")
|
||||
}
|
||||
if ci, ok := info.(*ctrld.ClientInfo); ok && existed && ci.Mac != tc.mac {
|
||||
t.Errorf("mac mismatched, got: %q, want: %q", ci.Mac, tc.mac)
|
||||
} else {
|
||||
t.Log(ci)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
72
internal/router/ddwrt.go
Normal file
72
internal/router/ddwrt.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
const (
|
||||
nvramCtrldKeyPrefix = "ctrld_"
|
||||
nvramCtrldSetupKey = "ctrld_setup"
|
||||
nvramCtrldInstallKey = "ctrld_install"
|
||||
nvramRCStartupKey = "rc_startup"
|
||||
)
|
||||
|
||||
//lint:ignore ST1005 This error is for human.
|
||||
var errDdwrtJffs2NotEnabled = errors.New(`could not install service without jffs, follow this guide to enable:
|
||||
|
||||
https://wiki.dd-wrt.com/wiki/index.php/Journalling_Flash_File_System
|
||||
`)
|
||||
|
||||
func setupDDWrt() error {
|
||||
// Already setup.
|
||||
if val, _ := nvram("get", nvramCtrldSetupKey); val == "1" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := dnsMasqConf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nvramKvMap := nvramSetupKV()
|
||||
nvramKvMap["dnsmasq_options"] = data
|
||||
if err := nvramSetKV(nvramKvMap, nvramCtrldSetupKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupDDWrt() error {
|
||||
// Restore old configs.
|
||||
if err := nvramRestore(nvramSetupKV(), nvramCtrldSetupKey); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postInstallDDWrt() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ddwrtRestartDNSMasq() error {
|
||||
if out, err := exec.Command("restart_dns").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("restart_dns: %s, %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ddwrtJff2Enabled() bool {
|
||||
out, _ := nvram("get", "enable_jffs2")
|
||||
return out == "1"
|
||||
}
|
||||
87
internal/router/dnsmasq.go
Normal file
87
internal/router/dnsmasq.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
const dnsMasqConfigContentTmpl = `# GENERATED BY ctrld - DO NOT MODIFY
|
||||
no-resolv
|
||||
server=127.0.0.1#5354
|
||||
{{- if .SendClientInfo}}
|
||||
add-mac
|
||||
{{- end}}
|
||||
`
|
||||
|
||||
const merlinDNSMasqPostConfPath = "/jffs/scripts/dnsmasq.postconf"
|
||||
const merlinDNSMasqPostConfMarker = `# GENERATED BY ctrld - EOF`
|
||||
|
||||
const merlinDNSMasqPostConfTmpl = `# GENERATED BY ctrld - DO NOT MODIFY
|
||||
|
||||
#!/bin/sh
|
||||
|
||||
config_file="$1"
|
||||
. /usr/sbin/helper.sh
|
||||
|
||||
pid=$(cat /tmp/ctrld.pid 2>/dev/null)
|
||||
if [ -n "$pid" ] && [ -f "/proc/${pid}/cmdline" ]; then
|
||||
pc_delete "servers-file" "$config_file" # no WAN DNS settings
|
||||
pc_append "no-resolv" "$config_file" # do not read /etc/resolv.conf
|
||||
pc_append "server=127.0.0.1#5354" "$config_file" # use ctrld as upstream
|
||||
{{- if .SendClientInfo}}
|
||||
pc_append "add-mac" "$config_file" # add client mac
|
||||
{{- end}}
|
||||
pc_delete "dnssec" "$config_file" # disable DNSSEC
|
||||
pc_delete "trust-anchor=" "$config_file" # disable DNSSEC
|
||||
|
||||
# For John fork
|
||||
pc_delete "resolv-file" "$config_file" # no WAN DNS settings
|
||||
|
||||
# Change /etc/resolv.conf, which may be changed by WAN DNS setup
|
||||
pc_delete "nameserver" /etc/resolv.conf
|
||||
pc_append "nameserver 127.0.0.1" /etc/resolv.conf
|
||||
|
||||
exit 0
|
||||
fi
|
||||
`
|
||||
|
||||
func dnsMasqConf() (string, error) {
|
||||
var sb strings.Builder
|
||||
var tmplText string
|
||||
switch Name() {
|
||||
case EdgeOS, DDWrt, OpenWrt, Ubios, Synology, Tomato:
|
||||
tmplText = dnsMasqConfigContentTmpl
|
||||
case Merlin:
|
||||
tmplText = merlinDNSMasqPostConfTmpl
|
||||
}
|
||||
tmpl := template.Must(template.New("").Parse(tmplText))
|
||||
var to = &struct {
|
||||
SendClientInfo bool
|
||||
}{
|
||||
routerPlatform.Load().sendClientInfo,
|
||||
}
|
||||
if err := tmpl.Execute(&sb, to); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func restartDNSMasq() error {
|
||||
switch Name() {
|
||||
case EdgeOS:
|
||||
return edgeOSRestartDNSMasq()
|
||||
case DDWrt:
|
||||
return ddwrtRestartDNSMasq()
|
||||
case Merlin:
|
||||
return merlinRestartDNSMasq()
|
||||
case OpenWrt:
|
||||
return openwrtRestartDNSMasq()
|
||||
case Ubios:
|
||||
return ubiosRestartDNSMasq()
|
||||
case Synology:
|
||||
return synologyRestartDNSMasq()
|
||||
case Tomato:
|
||||
return tomatoRestartService(tomatoDNSMasqSvcName)
|
||||
}
|
||||
panic("not supported platform")
|
||||
}
|
||||
56
internal/router/edgeos.go
Normal file
56
internal/router/edgeos.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
const edgeOSDNSMasqConfigPath = "/etc/dnsmasq.d/dnsmasq-zzz-ctrld.conf"
|
||||
|
||||
func setupEdgeOS() error {
|
||||
// Disable dnsmasq as DNS server.
|
||||
dnsMasqConfigContent, err := dnsMasqConf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(edgeOSDNSMasqConfigPath, []byte(dnsMasqConfigContent), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupEdgeOS() error {
|
||||
// Remove the custom dnsmasq config
|
||||
if err := os.Remove(edgeOSDNSMasqConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postInstallEdgeOS() error {
|
||||
// If "Content Filtering" is enabled, UniFi OS will create firewall rules to intercept all DNS queries
|
||||
// from outside, and route those queries to separated interfaces (e.g: dnsfilter-2@if79) created by UniFi OS.
|
||||
// Thus, those queries will never reach ctrld listener. UniFi OS does not provide any mechanism to toggle this
|
||||
// feature via command line, so there's nothing ctrld can do to disable this feature. For now, reporting an
|
||||
// error and guiding users to disable the feature using UniFi OS web UI.
|
||||
if contentFilteringEnabled() {
|
||||
return errContentFilteringEnabled
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func edgeOSRestartDNSMasq() error {
|
||||
if out, err := exec.Command("/etc/init.d/dnsmasq", "restart").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("edgeosRestartDNSMasq: %s, %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
89
internal/router/merlin.go
Normal file
89
internal/router/merlin.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func setupMerlin() error {
|
||||
buf, err := os.ReadFile(merlinDNSMasqPostConfPath)
|
||||
// Already setup.
|
||||
if bytes.Contains(buf, []byte(merlinDNSMasqPostConfMarker)) {
|
||||
return nil
|
||||
}
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
merlinDNSMasqPostConf, err := dnsMasqConf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data := strings.Join([]string{
|
||||
merlinDNSMasqPostConf,
|
||||
"\n",
|
||||
merlinDNSMasqPostConfMarker,
|
||||
"\n",
|
||||
string(buf),
|
||||
}, "\n")
|
||||
// Write dnsmasq post conf file.
|
||||
if err := os.WriteFile(merlinDNSMasqPostConfPath, []byte(data), 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := nvramSetKV(nvramSetupKV(), nvramCtrldSetupKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupMerlin() error {
|
||||
// Restore old configs.
|
||||
if err := nvramRestore(nvramSetupKV(), nvramCtrldSetupKey); err != nil {
|
||||
return err
|
||||
}
|
||||
buf, err := os.ReadFile(merlinDNSMasqPostConfPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
// Restore dnsmasq post conf file.
|
||||
if err := os.WriteFile(merlinDNSMasqPostConfPath, merlinParsePostConf(buf), 0750); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postInstallMerlin() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func merlinRestartDNSMasq() error {
|
||||
if out, err := exec.Command("service", "restart_dnsmasq").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("restart_dnsmasq: %s, %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func merlinParsePostConf(buf []byte) []byte {
|
||||
if len(buf) == 0 {
|
||||
return nil
|
||||
}
|
||||
parts := bytes.Split(buf, []byte(merlinDNSMasqPostConfMarker))
|
||||
if len(parts) != 1 {
|
||||
return bytes.TrimLeftFunc(parts[1], unicode.IsSpace)
|
||||
}
|
||||
return buf
|
||||
}
|
||||
38
internal/router/merlin_test.go
Normal file
38
internal/router/merlin_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_merlinParsePostConf(t *testing.T) {
|
||||
origContent := "# foo"
|
||||
data := strings.Join([]string{
|
||||
merlinDNSMasqPostConfTmpl,
|
||||
"\n",
|
||||
merlinDNSMasqPostConfMarker,
|
||||
"\n",
|
||||
}, "\n")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data string
|
||||
expected string
|
||||
}{
|
||||
{"empty", "", ""},
|
||||
{"no ctrld", origContent, origContent},
|
||||
{"ctrld with data", data + origContent, origContent},
|
||||
{"ctrld without data", data, ""},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
//t.Parallel()
|
||||
if got := merlinParsePostConf([]byte(tc.data)); !bytes.Equal(got, []byte(tc.expected)) {
|
||||
t.Errorf("unexpected result, want: %q, got: %q", tc.expected, string(got))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
110
internal/router/nvram.go
Normal file
110
internal/router/nvram.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func nvram(args ...string) (string, error) {
|
||||
cmd := exec.Command("nvram", args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return "", fmt.Errorf("%s:%w", stderr.String(), err)
|
||||
}
|
||||
return strings.TrimSpace(stdout.String()), nil
|
||||
}
|
||||
|
||||
/*
|
||||
NOTE:
|
||||
- For Openwrt, DNSSEC is not included in default dnsmasq (require dnsmasq-full).
|
||||
- For Merlin, DNSSEC is configured during postconf script (see merlinDNSMasqPostConfTmpl).
|
||||
- For Ubios UDM Pro/Dream Machine, DNSSEC is not included in their dnsmasq package:
|
||||
+https://community.ui.com/questions/Implement-DNSSEC-into-UniFi/951c72b0-4d88-4c86-9174-45417bd2f9ca
|
||||
+https://community.ui.com/questions/Enable-DNSSEC-for-Unifi-Dream-Machine-FW-updates/e68e367c-d09b-4459-9444-18908f7c1ea1
|
||||
*/
|
||||
func nvramSetupKV() map[string]string {
|
||||
switch Name() {
|
||||
case DDWrt:
|
||||
return map[string]string{
|
||||
"dns_dnsmasq": "1", // Make dnsmasq running but disable DNS ability, ctrld will replace it.
|
||||
"dnsmasq_options": "", // Configuration of dnsmasq set by ctrld, filled by setupDDWrt.
|
||||
"dns_crypt": "0", // Disable DNSCrypt.
|
||||
"dnssec": "0", // Disable DNSSEC.
|
||||
}
|
||||
case Merlin:
|
||||
return map[string]string{
|
||||
"dnspriv_enable": "0", // Ensure Merlin native DoT disabled.
|
||||
}
|
||||
case Tomato:
|
||||
return map[string]string{
|
||||
"dnsmasq_custom": "", // Configuration of dnsmasq set by ctrld, filled by setupTomato.
|
||||
"dnscrypt_proxy": "0", // Disable DNSCrypt.
|
||||
"dnssec_enable": "0", // Disable DNSSEC.
|
||||
"stubby_proxy": "0", // Disable Stubby
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nvramInstallKV() map[string]string {
|
||||
switch Name() {
|
||||
case Tomato:
|
||||
return map[string]string{
|
||||
tomatoNvramScriptWanupKey: "", // script to start ctrld, filled by tomatoSvc.Install method.
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nvramSetKV(m map[string]string, setupKey string) error {
|
||||
// Backup current value, store ctrld's configs.
|
||||
for key, value := range m {
|
||||
old, err := nvram("get", key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", old, err)
|
||||
}
|
||||
if out, err := nvram("set", nvramCtrldKeyPrefix+key+"="+old); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
if out, err := nvram("set", key+"="+value); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
}
|
||||
|
||||
if out, err := nvram("set", setupKey+"=1"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
// Commit.
|
||||
if out, err := nvram("commit"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nvramRestore(m map[string]string, setupKey string) error {
|
||||
// Restore old configs.
|
||||
for key := range m {
|
||||
ctrldKey := nvramCtrldKeyPrefix + key
|
||||
old, err := nvram("get", ctrldKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s: %w", old, err)
|
||||
}
|
||||
_, _ = nvram("unset", ctrldKey)
|
||||
if out, err := nvram("set", key+"="+old); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
}
|
||||
|
||||
if out, err := nvram("unset", setupKey); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
// Commit.
|
||||
if out, err := nvram("commit"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
94
internal/router/openwrt.go
Normal file
94
internal/router/openwrt.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var errUCIEntryNotFound = errors.New("uci: Entry not found")
|
||||
|
||||
const openwrtDNSMasqConfigPath = "/tmp/dnsmasq.d/ctrld.conf"
|
||||
|
||||
// IsGLiNet reports whether the router is an GL.iNet router.
|
||||
func IsGLiNet() bool {
|
||||
if Name() != OpenWrt {
|
||||
return false
|
||||
}
|
||||
buf, _ := os.ReadFile("/proc/version")
|
||||
// The output of /proc/version contains "(glinet@glinet)".
|
||||
return bytes.Contains(buf, []byte(" (glinet"))
|
||||
}
|
||||
|
||||
// IsOldOpenwrt reports whether the router is an "old" version of Openwrt,
|
||||
// aka versions which don't have "service" command.
|
||||
func IsOldOpenwrt() bool {
|
||||
if Name() != OpenWrt {
|
||||
return false
|
||||
}
|
||||
cmd, _ := exec.LookPath("service")
|
||||
return cmd == ""
|
||||
}
|
||||
|
||||
func setupOpenWrt() error {
|
||||
// Delete dnsmasq port if set.
|
||||
if _, err := uci("delete", "dhcp.@dnsmasq[0].port"); err != nil && !errors.Is(err, errUCIEntryNotFound) {
|
||||
return err
|
||||
}
|
||||
dnsMasqConfigContent, err := dnsMasqConf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(openwrtDNSMasqConfigPath, []byte(dnsMasqConfigContent), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Commit.
|
||||
if _, err := uci("commit"); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupOpenWrt() error {
|
||||
// Remove the custom dnsmasq config
|
||||
if err := os.Remove(openwrtDNSMasqConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postInstallOpenWrt() error {
|
||||
return exec.Command("/etc/init.d/ctrld", "enable").Run()
|
||||
}
|
||||
|
||||
func uci(args ...string) (string, error) {
|
||||
cmd := exec.Command("uci", args...)
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
if strings.HasPrefix(stderr.String(), errUCIEntryNotFound.Error()) {
|
||||
return "", errUCIEntryNotFound
|
||||
}
|
||||
return "", fmt.Errorf("%s:%w", stderr.String(), err)
|
||||
}
|
||||
return strings.TrimSpace(stdout.String()), nil
|
||||
}
|
||||
|
||||
func openwrtRestartDNSMasq() error {
|
||||
if out, err := exec.Command("/etc/init.d/dnsmasq", "restart").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%s: %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
66
internal/router/pfsense.go
Normal file
66
internal/router/pfsense.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
const (
|
||||
rcPath = "/usr/local/etc/rc.d"
|
||||
unboundRcPath = rcPath + "/unbound"
|
||||
dnsmasqRcPath = rcPath + "/dnsmasq"
|
||||
)
|
||||
|
||||
func setupPfsense() error {
|
||||
// If Pfsense is in DNS Resolver mode, ensure no unbound processes running.
|
||||
_ = exec.Command("killall", "unbound").Run()
|
||||
|
||||
// If Pfsense is in DNS Forwarder mode, ensure no dnsmasq processes running.
|
||||
_ = exec.Command("killall", "dnsmasq").Run()
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupPfsense(svc *service.Config) error {
|
||||
if err := os.Remove(filepath.Join(rcPath, svc.Name+".sh")); err != nil {
|
||||
return fmt.Errorf("os.Remove: %w", err)
|
||||
}
|
||||
_ = exec.Command(unboundRcPath, "onerestart").Run()
|
||||
_ = exec.Command(dnsmasqRcPath, "onerestart").Run()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func postInstallPfsense(svc *service.Config) error {
|
||||
// pfsense need ".sh" extension for script to be run at boot.
|
||||
// See: https://docs.netgate.com/pfsense/en/latest/development/boot-commands.html#shell-script-option
|
||||
oldname := filepath.Join(rcPath, svc.Name)
|
||||
newname := filepath.Join(rcPath, svc.Name+".sh")
|
||||
_ = os.Remove(newname)
|
||||
if err := os.Symlink(oldname, newname); err != nil {
|
||||
return fmt.Errorf("os.Symlink: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const pfsenseInitScript = `#!/bin/sh
|
||||
|
||||
# PROVIDE: {{.Name}}
|
||||
# REQUIRE: SERVERS
|
||||
# REQUIRE: unbound dnsmasq securelevel
|
||||
# KEYWORD: shutdown
|
||||
|
||||
. /etc/rc.subr
|
||||
|
||||
name="{{.Name}}"
|
||||
{{.Name}}_env="IS_DAEMON=1"
|
||||
pidfile="/var/run/${name}.pid"
|
||||
command="/usr/sbin/daemon"
|
||||
daemon_args="-P ${pidfile} -r -t \"${name}: daemon\"{{if .WorkingDirectory}} -c {{.WorkingDirectory}}{{end}}"
|
||||
command_args="${daemon_args} {{.Path}}{{range .Arguments}} {{.}}{{end}}"
|
||||
|
||||
run_rc_command "$1"
|
||||
`
|
||||
24
internal/router/procd.go
Normal file
24
internal/router/procd.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package router
|
||||
|
||||
const openWrtScript = `#!/bin/sh /etc/rc.common
|
||||
USE_PROCD=1
|
||||
# After network starts
|
||||
START=21
|
||||
# Before network stops
|
||||
STOP=89
|
||||
cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}"
|
||||
name="{{.Name}}"
|
||||
pid_file="/var/run/${name}.pid"
|
||||
|
||||
start_service() {
|
||||
echo "Starting ${name}"
|
||||
procd_open_instance
|
||||
procd_set_param command ${cmd}
|
||||
procd_set_param respawn # respawn automatically if something died
|
||||
procd_set_param stdout 1 # forward stdout of the command to logd
|
||||
procd_set_param stderr 1 # same for stderr
|
||||
procd_set_param pidfile ${pid_file} # write a pid file on instance start and remove it on stop
|
||||
procd_close_instance
|
||||
echo "${name} has been started"
|
||||
}
|
||||
`
|
||||
255
internal/router/router.go
Normal file
255
internal/router/router.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/kardianos/service"
|
||||
"tailscale.com/logtail/backoff"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
OpenWrt = "openwrt"
|
||||
DDWrt = "ddwrt"
|
||||
Merlin = "merlin"
|
||||
Ubios = "ubios"
|
||||
Synology = "synology"
|
||||
Tomato = "tomato"
|
||||
EdgeOS = "edgeos"
|
||||
Pfsense = "pfsense"
|
||||
)
|
||||
|
||||
// ErrNotSupported reports the current router is not supported error.
|
||||
var ErrNotSupported = errors.New("unsupported platform")
|
||||
|
||||
var routerPlatform atomic.Pointer[router]
|
||||
|
||||
type router struct {
|
||||
name string
|
||||
sendClientInfo bool
|
||||
mac sync.Map
|
||||
watcher *fsnotify.Watcher
|
||||
}
|
||||
|
||||
// IsSupported reports whether the given platform is supported by ctrld.
|
||||
func IsSupported(platform string) bool {
|
||||
switch platform {
|
||||
case EdgeOS, DDWrt, Merlin, OpenWrt, Pfsense, Synology, Tomato, Ubios:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SupportedPlatforms return all platforms that can be configured to run with ctrld.
|
||||
func SupportedPlatforms() []string {
|
||||
return []string{EdgeOS, DDWrt, Merlin, OpenWrt, Pfsense, Synology, Tomato, Ubios}
|
||||
}
|
||||
|
||||
var configureFunc = map[string]func() error{
|
||||
EdgeOS: setupEdgeOS,
|
||||
DDWrt: setupDDWrt,
|
||||
Merlin: setupMerlin,
|
||||
OpenWrt: setupOpenWrt,
|
||||
Pfsense: setupPfsense,
|
||||
Synology: setupSynology,
|
||||
Tomato: setupTomato,
|
||||
Ubios: setupUbiOS,
|
||||
}
|
||||
|
||||
// Configure configures things for running ctrld on the router.
|
||||
func Configure(c *ctrld.Config) error {
|
||||
name := Name()
|
||||
switch name {
|
||||
case EdgeOS, DDWrt, Merlin, OpenWrt, Pfsense, Synology, Tomato, Ubios:
|
||||
if c.HasUpstreamSendClientInfo() {
|
||||
r := routerPlatform.Load()
|
||||
r.sendClientInfo = true
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.watcher = watcher
|
||||
go r.watchClientInfoTable()
|
||||
for file, readClienInfoFunc := range clientInfoFiles {
|
||||
_ = readClienInfoFunc(file)
|
||||
_ = r.watcher.Add(file)
|
||||
}
|
||||
}
|
||||
configure := configureFunc[name]
|
||||
if err := configure(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureService performs necessary setup for running ctrld as a service on router.
|
||||
func ConfigureService(sc *service.Config) error {
|
||||
name := Name()
|
||||
switch name {
|
||||
case DDWrt:
|
||||
if !ddwrtJff2Enabled() {
|
||||
return errDdwrtJffs2NotEnabled
|
||||
}
|
||||
case OpenWrt:
|
||||
sc.Option["SysvScript"] = openWrtScript
|
||||
case Pfsense:
|
||||
sc.Option["SysvScript"] = pfsenseInitScript
|
||||
case EdgeOS, Merlin, Synology, Tomato, Ubios:
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PreRun blocks until the router is ready for running ctrld.
|
||||
func PreRun() (err error) {
|
||||
// On some routers, NTP may out of sync, so waiting for it to be ready.
|
||||
switch Name() {
|
||||
case Merlin, Tomato:
|
||||
// Wait until `ntp_ready=1` set.
|
||||
b := backoff.NewBackoff("PreStart", func(format string, args ...any) {}, 10*time.Second)
|
||||
for {
|
||||
out, err := nvram("get", "ntp_ready")
|
||||
if err != nil {
|
||||
return fmt.Errorf("PreStart: nvram: %w", err)
|
||||
}
|
||||
if out == "1" {
|
||||
return nil
|
||||
}
|
||||
b.BackOff(context.Background(), errors.New("ntp not ready"))
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// PostInstall performs task after installing ctrld on router.
|
||||
func PostInstall(svc *service.Config) error {
|
||||
name := Name()
|
||||
switch name {
|
||||
case EdgeOS:
|
||||
return postInstallEdgeOS()
|
||||
case DDWrt:
|
||||
return postInstallDDWrt()
|
||||
case Merlin:
|
||||
return postInstallMerlin()
|
||||
case OpenWrt:
|
||||
return postInstallOpenWrt()
|
||||
case Pfsense:
|
||||
return postInstallPfsense(svc)
|
||||
case Synology:
|
||||
return postInstallSynology()
|
||||
case Tomato:
|
||||
return postInstallTomato()
|
||||
case Ubios:
|
||||
return postInstallUbiOS()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup cleans ctrld setup on the router.
|
||||
func Cleanup(svc *service.Config) error {
|
||||
name := Name()
|
||||
switch name {
|
||||
case EdgeOS:
|
||||
return cleanupEdgeOS()
|
||||
case DDWrt:
|
||||
return cleanupDDWrt()
|
||||
case Merlin:
|
||||
return cleanupMerlin()
|
||||
case OpenWrt:
|
||||
return cleanupOpenWrt()
|
||||
case Pfsense:
|
||||
return cleanupPfsense(svc)
|
||||
case Synology:
|
||||
return cleanupSynology()
|
||||
case Tomato:
|
||||
return cleanupTomato()
|
||||
case Ubios:
|
||||
return cleanupUbiOS()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListenAddress returns the listener address of ctrld on router.
|
||||
func ListenAddress() string {
|
||||
name := Name()
|
||||
switch name {
|
||||
case EdgeOS, DDWrt, Merlin, OpenWrt, Synology, Tomato, Ubios:
|
||||
return "127.0.0.1:5354"
|
||||
case Pfsense:
|
||||
// On pfsense, we run ctrld as DNS resolver.
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Name returns name of the router platform.
|
||||
func Name() string {
|
||||
if r := routerPlatform.Load(); r != nil {
|
||||
return r.name
|
||||
}
|
||||
r := &router{}
|
||||
r.name = distroName()
|
||||
routerPlatform.Store(r)
|
||||
return r.name
|
||||
}
|
||||
|
||||
func distroName() string {
|
||||
switch {
|
||||
case bytes.HasPrefix(unameO(), []byte("DD-WRT")):
|
||||
return DDWrt
|
||||
case bytes.HasPrefix(unameO(), []byte("ASUSWRT-Merlin")):
|
||||
return Merlin
|
||||
case haveFile("/etc/openwrt_version"):
|
||||
return OpenWrt
|
||||
case haveDir("/data/unifi"):
|
||||
return Ubios
|
||||
case bytes.HasPrefix(unameU(), []byte("synology")):
|
||||
return Synology
|
||||
case bytes.HasPrefix(unameO(), []byte("Tomato")):
|
||||
return Tomato
|
||||
case haveDir("/config/scripts/post-config.d"):
|
||||
return EdgeOS
|
||||
case haveFile("/etc/ubnt/init/vyatta-router"):
|
||||
return EdgeOS // For 2.x
|
||||
case isPfsense():
|
||||
return Pfsense
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func haveFile(file string) bool {
|
||||
_, err := os.Stat(file)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func haveDir(dir string) bool {
|
||||
fi, _ := os.Stat(dir)
|
||||
return fi != nil && fi.IsDir()
|
||||
}
|
||||
|
||||
func unameO() []byte {
|
||||
out, _ := exec.Command("uname", "-o").Output()
|
||||
return out
|
||||
}
|
||||
|
||||
func unameU() []byte {
|
||||
out, _ := exec.Command("uname", "-u").Output()
|
||||
return out
|
||||
}
|
||||
|
||||
func isPfsense() bool {
|
||||
b, err := os.ReadFile("/etc/platform")
|
||||
return err == nil && bytes.HasPrefix(b, []byte("pfSense"))
|
||||
}
|
||||
91
internal/router/service.go
Normal file
91
internal/router/service.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func init() {
|
||||
systems := []service.System{
|
||||
&linuxSystemService{
|
||||
name: "ddwrt",
|
||||
detect: func() bool { return Name() == DDWrt },
|
||||
interactive: func() bool {
|
||||
is, _ := isInteractive()
|
||||
return is
|
||||
},
|
||||
new: newddwrtService,
|
||||
},
|
||||
&linuxSystemService{
|
||||
name: "merlin",
|
||||
detect: func() bool { return Name() == Merlin },
|
||||
interactive: func() bool {
|
||||
is, _ := isInteractive()
|
||||
return is
|
||||
},
|
||||
new: newMerlinService,
|
||||
},
|
||||
&linuxSystemService{
|
||||
name: "ubios",
|
||||
detect: func() bool {
|
||||
if Name() != Ubios {
|
||||
return false
|
||||
}
|
||||
out, err := exec.Command("ubnt-device-info", "firmware").CombinedOutput()
|
||||
if err == nil {
|
||||
// For v2/v3, UbiOS use a Debian base with systemd, so it is not
|
||||
// necessary to use custom implementation for supporting init system.
|
||||
return bytes.HasPrefix(out, []byte("1."))
|
||||
}
|
||||
return true
|
||||
},
|
||||
interactive: func() bool {
|
||||
is, _ := isInteractive()
|
||||
return is
|
||||
},
|
||||
new: newUbiosService,
|
||||
},
|
||||
&linuxSystemService{
|
||||
name: "tomato",
|
||||
detect: func() bool { return Name() == Tomato },
|
||||
interactive: func() bool {
|
||||
is, _ := isInteractive()
|
||||
return is
|
||||
},
|
||||
new: newTomatoService,
|
||||
},
|
||||
}
|
||||
systems = append(systems, service.AvailableSystems()...)
|
||||
service.ChooseSystem(systems...)
|
||||
}
|
||||
|
||||
type linuxSystemService struct {
|
||||
name string
|
||||
detect func() bool
|
||||
interactive func() bool
|
||||
new func(i service.Interface, platform string, c *service.Config) (service.Service, error)
|
||||
}
|
||||
|
||||
func (sc linuxSystemService) String() string {
|
||||
return sc.name
|
||||
}
|
||||
func (sc linuxSystemService) Detect() bool {
|
||||
return sc.detect()
|
||||
}
|
||||
func (sc linuxSystemService) Interactive() bool {
|
||||
return sc.interactive()
|
||||
}
|
||||
func (sc linuxSystemService) New(i service.Interface, c *service.Config) (service.Service, error) {
|
||||
return sc.new(i, sc.String(), c)
|
||||
}
|
||||
|
||||
func isInteractive() (bool, error) {
|
||||
ppid := os.Getppid()
|
||||
if ppid == 1 {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
292
internal/router/service_ddwrt.go
Normal file
292
internal/router/service_ddwrt.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"text/template"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
type ddwrtSvc struct {
|
||||
i service.Interface
|
||||
platform string
|
||||
*service.Config
|
||||
rcStartup string
|
||||
}
|
||||
|
||||
func newddwrtService(i service.Interface, platform string, c *service.Config) (service.Service, error) {
|
||||
s := &ddwrtSvc{
|
||||
i: i,
|
||||
platform: platform,
|
||||
Config: c,
|
||||
}
|
||||
if err := os.MkdirAll("/jffs/etc/config", 0644); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) String() string {
|
||||
if len(s.DisplayName) > 0 {
|
||||
return s.DisplayName
|
||||
}
|
||||
return s.Name
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Platform() string {
|
||||
return s.platform
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) configPath() string {
|
||||
return fmt.Sprintf("/jffs/etc/config/%s.startup", s.Config.Name)
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) template() *template.Template {
|
||||
return template.Must(template.New("").Parse(ddwrtSvcScript))
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Install() error {
|
||||
confPath := s.configPath()
|
||||
if _, err := os.Stat(confPath); err == nil {
|
||||
return fmt.Errorf("already installed: %s", confPath)
|
||||
}
|
||||
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(path, "/jffs/") {
|
||||
return errors.New("could not install service outside /jffs")
|
||||
}
|
||||
|
||||
var to = &struct {
|
||||
*service.Config
|
||||
Path string
|
||||
}{
|
||||
s.Config,
|
||||
path,
|
||||
}
|
||||
|
||||
f, err := os.Create(confPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := s.template().Execute(f, to); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = os.Chmod(confPath, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
if err := template.Must(template.New("").Parse(ddwrtStartupCmd)).Execute(&sb, to); err != nil {
|
||||
return err
|
||||
}
|
||||
s.rcStartup = sb.String()
|
||||
curVal, err := nvram("get", nvramRCStartupKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := nvram("set", nvramCtrldKeyPrefix+nvramRCStartupKey+"="+curVal); err != nil {
|
||||
return err
|
||||
}
|
||||
val := strings.Join([]string{curVal, s.rcStartup + " &", fmt.Sprintf(`echo $! > "/tmp/%s.pid"`, s.Config.Name)}, "\n")
|
||||
|
||||
if _, err := nvram("set", nvramRCStartupKey+"="+val); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := nvram("commit"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Uninstall() error {
|
||||
if err := os.Remove(s.configPath()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctrldStartupKey := nvramCtrldKeyPrefix + nvramRCStartupKey
|
||||
rcStartup, err := nvram("get", ctrldStartupKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = nvram("unset", ctrldStartupKey)
|
||||
if _, err := nvram("set", nvramRCStartupKey+"="+rcStartup); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := nvram("commit"); err != nil {
|
||||
return fmt.Errorf("%s: %w", out, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Logger(errs chan<- error) (service.Logger, error) {
|
||||
if service.Interactive() {
|
||||
return service.ConsoleLogger, nil
|
||||
}
|
||||
return s.SystemLogger(errs)
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) SystemLogger(errs chan<- error) (service.Logger, error) {
|
||||
// TODO(cuonglm): detect syslog enable and return proper logger?
|
||||
// this at least works with default configuration.
|
||||
if service.Interactive() {
|
||||
return service.ConsoleLogger, nil
|
||||
|
||||
}
|
||||
return &noopLogger{}, nil
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Run() (err error) {
|
||||
err = s.i.Start(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if interactice, _ := isInteractive(); !interactice {
|
||||
signal.Ignore(syscall.SIGHUP)
|
||||
}
|
||||
var sigChan = make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt)
|
||||
<-sigChan
|
||||
|
||||
return s.i.Stop(s)
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Status() (service.Status, error) {
|
||||
if _, err := os.Stat(s.configPath()); os.IsNotExist(err) {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
out, err := exec.Command(s.configPath(), "status").CombinedOutput()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, err
|
||||
}
|
||||
switch string(bytes.TrimSpace(out)) {
|
||||
case "running":
|
||||
return service.StatusRunning, nil
|
||||
default:
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Start() error {
|
||||
return exec.Command(s.configPath(), "start").Run()
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Stop() error {
|
||||
return exec.Command(s.configPath(), "stop").Run()
|
||||
}
|
||||
|
||||
func (s *ddwrtSvc) Restart() error {
|
||||
err := s.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Start()
|
||||
}
|
||||
|
||||
type noopLogger struct {
|
||||
}
|
||||
|
||||
func (c noopLogger) Error(v ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Warning(v ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Info(v ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Errorf(format string, a ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Warningf(format string, a ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (c noopLogger) Infof(format string, a ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
const ddwrtStartupCmd = `{{.Path}}{{range .Arguments}} {{.}}{{end}}`
|
||||
const ddwrtSvcScript = `#!/bin/sh
|
||||
|
||||
name="{{.Name}}"
|
||||
cmd="{{.Path}}{{range .Arguments}} {{.}}{{end}}"
|
||||
pid_file="/tmp/$name.pid"
|
||||
|
||||
get_pid() {
|
||||
cat "$pid_file"
|
||||
}
|
||||
|
||||
is_running() {
|
||||
[ -f "$pid_file" ] && ps | grep -q "^ *$(get_pid) "
|
||||
}
|
||||
|
||||
case "$1" in
|
||||
start)
|
||||
if is_running; then
|
||||
echo "Already started"
|
||||
else
|
||||
echo "Starting $name"
|
||||
$cmd &
|
||||
echo $! > "$pid_file"
|
||||
chmod 600 "$pid_file"
|
||||
if ! is_running; then
|
||||
echo "Failed to start $name"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
stop)
|
||||
if is_running; then
|
||||
echo -n "Stopping $name..."
|
||||
kill "$(get_pid)"
|
||||
for _ in 1 2 3 4 5; do
|
||||
if ! is_running; then
|
||||
echo "stopped"
|
||||
if [ -f "$pid_file" ]; then
|
||||
rm "$pid_file"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
printf "."
|
||||
sleep 2
|
||||
done
|
||||
echo "failed to stop $name"
|
||||
exit 1
|
||||
fi
|
||||
exit 1
|
||||
;;
|
||||
restart)
|
||||
$0 stop
|
||||
$0 start
|
||||
;;
|
||||
status)
|
||||
if is_running; then
|
||||
echo "running"
|
||||
else
|
||||
echo "stopped"
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {start|stop|restart|status}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
exit 0
|
||||
`
|
||||
354
internal/router/service_merlin.go
Normal file
354
internal/router/service_merlin.go
Normal file
@@ -0,0 +1,354 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"text/template"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
const (
|
||||
merlinJFFSScriptPath = "/jffs/scripts/services-start"
|
||||
merlinJFFSServiceEventScriptPath = "/jffs/scripts/service-event"
|
||||
)
|
||||
|
||||
type merlinSvc struct {
|
||||
i service.Interface
|
||||
platform string
|
||||
*service.Config
|
||||
}
|
||||
|
||||
func newMerlinService(i service.Interface, platform string, c *service.Config) (service.Service, error) {
|
||||
s := &merlinSvc{
|
||||
i: i,
|
||||
platform: platform,
|
||||
Config: c,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *merlinSvc) String() string {
|
||||
if len(s.DisplayName) > 0 {
|
||||
return s.DisplayName
|
||||
}
|
||||
return s.Name
|
||||
}
|
||||
|
||||
func (s *merlinSvc) Platform() string {
|
||||
return s.platform
|
||||
}
|
||||
|
||||
func (s *merlinSvc) configPath() string {
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return path + ".startup"
|
||||
}
|
||||
|
||||
func (s *merlinSvc) template() *template.Template {
|
||||
return template.Must(template.New("").Parse(merlinSvcScript))
|
||||
}
|
||||
|
||||
func (s *merlinSvc) Install() error {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(exePath, "/jffs/") {
|
||||
return errors.New("could not install service outside /jffs")
|
||||
}
|
||||
if _, err := nvram("set", "jffs2_scripts=1"); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := nvram("commit"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
confPath := s.configPath()
|
||||
if _, err := os.Stat(confPath); err == nil {
|
||||
return fmt.Errorf("already installed: %s", confPath)
|
||||
}
|
||||
|
||||
var to = &struct {
|
||||
*service.Config
|
||||
Path string
|
||||
}{
|
||||
s.Config,
|
||||
exePath,
|
||||
}
|
||||
|
||||
f, err := os.Create(confPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("os.Create: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := s.template().Execute(f, to); err != nil {
|
||||
return fmt.Errorf("s.template.Execute: %w", err)
|
||||
}
|
||||
|
||||
if err = os.Chmod(confPath, 0755); err != nil {
|
||||
return fmt.Errorf("os.Chmod: startup script: %w", err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(merlinJFFSScriptPath), 0755); err != nil {
|
||||
return fmt.Errorf("os.MkdirAll: %w", err)
|
||||
}
|
||||
|
||||
tmpScript, err := os.CreateTemp("", "ctrld_install")
|
||||
if err != nil {
|
||||
return fmt.Errorf("os.CreateTemp: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpScript.Name())
|
||||
defer tmpScript.Close()
|
||||
|
||||
if _, err := tmpScript.WriteString(merlinAddLineToScript); err != nil {
|
||||
return fmt.Errorf("tmpScript.WriteString: %w", err)
|
||||
}
|
||||
if err := tmpScript.Close(); err != nil {
|
||||
return fmt.Errorf("tmpScript.Close: %w", err)
|
||||
}
|
||||
addLineToScript := func(line, script string) error {
|
||||
if _, err := os.Stat(script); os.IsNotExist(err) {
|
||||
if err := os.WriteFile(script, []byte("#!/bin/sh\n"), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := os.Chmod(script, 0755); err != nil {
|
||||
return fmt.Errorf("os.Chmod: jffs script: %w", err)
|
||||
}
|
||||
|
||||
if err := exec.Command("sh", tmpScript.Name(), line, script).Run(); err != nil {
|
||||
return fmt.Errorf("exec.Command: add startup script: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for script, line := range map[string]string{
|
||||
merlinJFFSScriptPath: s.configPath() + " start",
|
||||
merlinJFFSServiceEventScriptPath: s.configPath() + ` service_event "$1" "$2"`,
|
||||
} {
|
||||
if err := addLineToScript(line, script); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *merlinSvc) Uninstall() error {
|
||||
if err := os.Remove(s.configPath()); err != nil {
|
||||
return fmt.Errorf("os.Remove: %w", err)
|
||||
}
|
||||
tmpScript, err := os.CreateTemp("", "ctrld_uninstall")
|
||||
if err != nil {
|
||||
return fmt.Errorf("os.CreateTemp: %w", err)
|
||||
}
|
||||
defer os.Remove(tmpScript.Name())
|
||||
defer tmpScript.Close()
|
||||
|
||||
if _, err := tmpScript.WriteString(merlinRemoveLineFromScript); err != nil {
|
||||
return fmt.Errorf("tmpScript.WriteString: %w", err)
|
||||
}
|
||||
if err := tmpScript.Close(); err != nil {
|
||||
return fmt.Errorf("tmpScript.Close: %w", err)
|
||||
}
|
||||
removeLineFromScript := func(line, script string) error {
|
||||
if _, err := os.Stat(script); os.IsNotExist(err) {
|
||||
if err := os.WriteFile(script, []byte("#!/bin/sh\n"), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := os.Chmod(script, 0755); err != nil {
|
||||
return fmt.Errorf("os.Chmod: jffs script: %w", err)
|
||||
}
|
||||
|
||||
if err := exec.Command("sh", tmpScript.Name(), line, script).Run(); err != nil {
|
||||
return fmt.Errorf("exec.Command: add startup script: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for script, line := range map[string]string{
|
||||
merlinJFFSScriptPath: s.configPath() + " start",
|
||||
merlinJFFSServiceEventScriptPath: s.configPath() + ` service_event "$1" "$2"`,
|
||||
} {
|
||||
if err := removeLineFromScript(line, script); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *merlinSvc) Logger(errs chan<- error) (service.Logger, error) {
|
||||
if service.Interactive() {
|
||||
return service.ConsoleLogger, nil
|
||||
}
|
||||
return s.SystemLogger(errs)
|
||||
}
|
||||
|
||||
func (s *merlinSvc) SystemLogger(errs chan<- error) (service.Logger, error) {
|
||||
return newSysLogger(s.Name, errs)
|
||||
}
|
||||
|
||||
func (s *merlinSvc) Run() (err error) {
|
||||
err = s.i.Start(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if interactice, _ := isInteractive(); !interactice {
|
||||
signal.Ignore(syscall.SIGHUP)
|
||||
}
|
||||
|
||||
var sigChan = make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt)
|
||||
<-sigChan
|
||||
|
||||
return s.i.Stop(s)
|
||||
}
|
||||
|
||||
func (s *merlinSvc) Status() (service.Status, error) {
|
||||
if _, err := os.Stat(s.configPath()); os.IsNotExist(err) {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
out, err := exec.Command(s.configPath(), "status").CombinedOutput()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, err
|
||||
}
|
||||
switch string(bytes.TrimSpace(out)) {
|
||||
case "running":
|
||||
return service.StatusRunning, nil
|
||||
default:
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *merlinSvc) Start() error {
|
||||
return exec.Command(s.configPath(), "start").Run()
|
||||
}
|
||||
|
||||
func (s *merlinSvc) Stop() error {
|
||||
return exec.Command(s.configPath(), "stop").Run()
|
||||
}
|
||||
|
||||
func (s *merlinSvc) Restart() error {
|
||||
err := s.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Start()
|
||||
}
|
||||
|
||||
const merlinSvcScript = `#!/bin/sh
|
||||
|
||||
name="{{.Name}}"
|
||||
cmd="{{.Path}}{{range .Arguments}} {{.}}{{end}}"
|
||||
pid_file="/tmp/$name.pid"
|
||||
|
||||
get_pid() {
|
||||
cat "$pid_file"
|
||||
}
|
||||
|
||||
is_running() {
|
||||
[ -f "$pid_file" ] && ps | grep -q "^ *$(get_pid) "
|
||||
}
|
||||
|
||||
case "$1" in
|
||||
start)
|
||||
if is_running; then
|
||||
logger -c "Already started"
|
||||
else
|
||||
logger -c "Starting $name"
|
||||
if [ -f /rom/ca-bundle.crt ]; then
|
||||
# For John’s fork
|
||||
export SSL_CERT_FILE=/rom/ca-bundle.crt
|
||||
fi
|
||||
$cmd &
|
||||
echo $! > "$pid_file"
|
||||
chmod 600 "$pid_file"
|
||||
if ! is_running; then
|
||||
logger -c "Failed to start $name"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
stop)
|
||||
if is_running; then
|
||||
logger -c "Stopping $name..."
|
||||
kill "$(get_pid)"
|
||||
for _ in 1 2 3 4 5; do
|
||||
if ! is_running; then
|
||||
logger -c "stopped"
|
||||
if [ -f "$pid_file" ]; then
|
||||
rm "$pid_file"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
printf "."
|
||||
sleep 2
|
||||
done
|
||||
logger -c "failed to stop $name"
|
||||
exit 1
|
||||
fi
|
||||
exit 1
|
||||
;;
|
||||
restart)
|
||||
$0 stop
|
||||
$0 start
|
||||
;;
|
||||
status)
|
||||
if is_running; then
|
||||
echo "running"
|
||||
else
|
||||
echo "stopped"
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
service_event)
|
||||
event=$2
|
||||
svc=$3
|
||||
dnsmasq_pid_file=$(sed -n '/pid-file=/s///p' /etc/dnsmasq.conf)
|
||||
|
||||
if [ "$event" = "restart" ] && [ "$svc" = "diskmon" ]; then
|
||||
kill "$(cat "$dnsmasq_pid_file")" >/dev/null 2>&1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {start|stop|restart|status}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
exit 0
|
||||
`
|
||||
|
||||
const merlinAddLineToScript = `#!/bin/sh
|
||||
|
||||
line=$1
|
||||
file=$2
|
||||
|
||||
. /usr/sbin/helper.sh
|
||||
|
||||
pc_append "$line" "$file"
|
||||
`
|
||||
|
||||
const merlinRemoveLineFromScript = `#!/bin/sh
|
||||
|
||||
line=$1
|
||||
file=$2
|
||||
|
||||
. /usr/sbin/helper.sh
|
||||
|
||||
pc_delete "$line" "$file"
|
||||
`
|
||||
278
internal/router/service_tomato.go
Normal file
278
internal/router/service_tomato.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
"text/template"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
const tomatoNvramScriptWanupKey = "script_wanup"
|
||||
|
||||
type tomatoSvc struct {
|
||||
i service.Interface
|
||||
platform string
|
||||
*service.Config
|
||||
}
|
||||
|
||||
func newTomatoService(i service.Interface, platform string, c *service.Config) (service.Service, error) {
|
||||
s := &tomatoSvc{
|
||||
i: i,
|
||||
platform: platform,
|
||||
Config: c,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) String() string {
|
||||
if len(s.DisplayName) > 0 {
|
||||
return s.DisplayName
|
||||
}
|
||||
return s.Name
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) Platform() string {
|
||||
return s.platform
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) configPath() string {
|
||||
path, err := os.Executable()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return path + ".startup"
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) template() *template.Template {
|
||||
return template.Must(template.New("").Parse(tomatoSvcScript))
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) Install() error {
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(exePath, "/jffs/") {
|
||||
return errors.New("could not install service outside /jffs")
|
||||
}
|
||||
if _, err := nvram("set", "jffs2_on=1"); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := nvram("commit"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
confPath := s.configPath()
|
||||
if _, err := os.Stat(confPath); err == nil {
|
||||
return fmt.Errorf("already installed: %s", confPath)
|
||||
}
|
||||
|
||||
var to = &struct {
|
||||
*service.Config
|
||||
Path string
|
||||
}{
|
||||
s.Config,
|
||||
exePath,
|
||||
}
|
||||
|
||||
f, err := os.Create(confPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("os.Create: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := s.template().Execute(f, to); err != nil {
|
||||
return fmt.Errorf("s.template.Execute: %w", err)
|
||||
}
|
||||
|
||||
if err = os.Chmod(confPath, 0755); err != nil {
|
||||
return fmt.Errorf("os.Chmod: startup script: %w", err)
|
||||
}
|
||||
|
||||
nvramKvMap := nvramInstallKV()
|
||||
old, err := nvram("get", tomatoNvramScriptWanupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nvram: %w", err)
|
||||
}
|
||||
nvramKvMap[tomatoNvramScriptWanupKey] = strings.Join([]string{old, s.configPath() + " start"}, "\n")
|
||||
if err := nvramSetKV(nvramKvMap, nvramCtrldInstallKey); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) Uninstall() error {
|
||||
if err := os.Remove(s.configPath()); err != nil {
|
||||
return fmt.Errorf("os.Remove: %w", err)
|
||||
}
|
||||
// Restore old configs.
|
||||
if err := nvramRestore(nvramInstallKV(), nvramCtrldInstallKey); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) Logger(errs chan<- error) (service.Logger, error) {
|
||||
if service.Interactive() {
|
||||
return service.ConsoleLogger, nil
|
||||
}
|
||||
return s.SystemLogger(errs)
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) SystemLogger(errs chan<- error) (service.Logger, error) {
|
||||
return newSysLogger(s.Name, errs)
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) Run() (err error) {
|
||||
err = s.i.Start(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if interactice, _ := isInteractive(); !interactice {
|
||||
signal.Ignore(syscall.SIGHUP)
|
||||
}
|
||||
|
||||
var sigChan = make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt)
|
||||
<-sigChan
|
||||
|
||||
return s.i.Stop(s)
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) Status() (service.Status, error) {
|
||||
if _, err := os.Stat(s.configPath()); os.IsNotExist(err) {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
out, err := exec.Command(s.configPath(), "status").CombinedOutput()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, err
|
||||
}
|
||||
switch string(bytes.TrimSpace(out)) {
|
||||
case "running":
|
||||
return service.StatusRunning, nil
|
||||
default:
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) Start() error {
|
||||
return exec.Command(s.configPath(), "start").Run()
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) Stop() error {
|
||||
return exec.Command(s.configPath(), "stop").Run()
|
||||
}
|
||||
|
||||
func (s *tomatoSvc) Restart() error {
|
||||
return exec.Command(s.configPath(), "restart").Run()
|
||||
}
|
||||
|
||||
// https://wiki.freshtomato.org/doku.php/freshtomato_zerotier?s[]=%2Aservice%2A
|
||||
const tomatoSvcScript = `#!/bin/sh
|
||||
|
||||
|
||||
NAME="{{.Name}}"
|
||||
CMD="{{.Path}}{{range .Arguments}} {{.}}{{end}}"
|
||||
LOG_FILE="/var/log/${NAME}.log"
|
||||
PID_FILE="/tmp/$NAME.pid"
|
||||
|
||||
|
||||
alias elog="logger -t $NAME -s"
|
||||
|
||||
|
||||
COND=$1
|
||||
[ $# -eq 0 ] && COND="start"
|
||||
|
||||
get_pid() {
|
||||
cat "$PID_FILE"
|
||||
}
|
||||
|
||||
is_running() {
|
||||
[ -f "$PID_FILE" ] && ps | grep -q "^ *$(get_pid) "
|
||||
}
|
||||
|
||||
start() {
|
||||
if is_running; then
|
||||
elog "$NAME is already running."
|
||||
exit 1
|
||||
fi
|
||||
elog "Starting $NAME Services: "
|
||||
$CMD &
|
||||
echo $! > "$PID_FILE"
|
||||
chmod 600 "$PID_FILE"
|
||||
if is_running; then
|
||||
elog "succeeded."
|
||||
else
|
||||
elog "failed."
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
stop() {
|
||||
if ! is_running; then
|
||||
elog "$NAME is not running."
|
||||
exit 1
|
||||
fi
|
||||
elog "Shutting down $NAME Services: "
|
||||
kill -SIGTERM "$(get_pid)"
|
||||
for _ in 1 2 3 4 5; do
|
||||
if ! is_running; then
|
||||
if [ -f "$pid_file" ]; then
|
||||
rm "$pid_file"
|
||||
fi
|
||||
return 0
|
||||
fi
|
||||
printf "."
|
||||
sleep 2
|
||||
done
|
||||
if ! is_running; then
|
||||
elog "succeeded."
|
||||
else
|
||||
elog "failed."
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
do_restart() {
|
||||
stop
|
||||
start
|
||||
}
|
||||
|
||||
|
||||
do_status() {
|
||||
if ! is_running; then
|
||||
echo "stopped"
|
||||
else
|
||||
echo "running"
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
case "$COND" in
|
||||
start)
|
||||
start
|
||||
;;
|
||||
stop)
|
||||
stop
|
||||
;;
|
||||
restart)
|
||||
do_restart
|
||||
;;
|
||||
status)
|
||||
do_status
|
||||
;;
|
||||
*)
|
||||
elog "Usage: $0 (start|stop|restart|status)"
|
||||
;;
|
||||
esac
|
||||
exit 0
|
||||
`
|
||||
336
internal/router/service_ubios.go
Normal file
336
internal/router/service_ubios.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// This is a copy of https://github.com/kardianos/service/blob/v1.2.1/service_sysv_linux.go,
|
||||
// with modification for supporting ubios v1 init system.
|
||||
|
||||
type ubiosSvc struct {
|
||||
i service.Interface
|
||||
platform string
|
||||
*service.Config
|
||||
}
|
||||
|
||||
func newUbiosService(i service.Interface, platform string, c *service.Config) (service.Service, error) {
|
||||
s := &ubiosSvc{
|
||||
i: i,
|
||||
platform: platform,
|
||||
Config: c,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) String() string {
|
||||
if len(s.DisplayName) > 0 {
|
||||
return s.DisplayName
|
||||
}
|
||||
return s.Name
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) Platform() string {
|
||||
return s.platform
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) configPath() string {
|
||||
return "/etc/init.d/" + s.Config.Name
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) execPath() (string, error) {
|
||||
if len(s.Executable) != 0 {
|
||||
return filepath.Abs(s.Executable)
|
||||
}
|
||||
return os.Executable()
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) template() *template.Template {
|
||||
return template.Must(template.New("").Funcs(tf).Parse(ubiosSvcScript))
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) Install() error {
|
||||
confPath := s.configPath()
|
||||
if _, err := os.Stat(confPath); err == nil {
|
||||
return fmt.Errorf("init already exists: %s", confPath)
|
||||
}
|
||||
|
||||
f, err := os.Create(confPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create config path: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
path, err := s.execPath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get exec path: %w", err)
|
||||
}
|
||||
|
||||
var to = &struct {
|
||||
*service.Config
|
||||
Path string
|
||||
DnsMasqConfPath string
|
||||
}{
|
||||
s.Config,
|
||||
path,
|
||||
ubiosDNSMasqConfigPath,
|
||||
}
|
||||
|
||||
if err := s.template().Execute(f, to); err != nil {
|
||||
return fmt.Errorf("failed to create init script: %w", err)
|
||||
}
|
||||
|
||||
if err := f.Close(); err != nil {
|
||||
return fmt.Errorf("failed to save init script: %w", err)
|
||||
}
|
||||
|
||||
if err = os.Chmod(confPath, 0755); err != nil {
|
||||
return fmt.Errorf("failed to set init script executable: %w", err)
|
||||
}
|
||||
|
||||
// Enable on boot
|
||||
script, err := os.CreateTemp("", "ctrld_boot.service")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create boot service tmp file: %w", err)
|
||||
}
|
||||
defer script.Close()
|
||||
|
||||
svcConfig := *to.Config
|
||||
svcConfig.Arguments = os.Args[1:]
|
||||
to.Config = &svcConfig
|
||||
if err := template.Must(template.New("").Funcs(tf).Parse(ubiosBootSystemdService)).Execute(script, &to); err != nil {
|
||||
return fmt.Errorf("failed to create boot service file: %w", err)
|
||||
}
|
||||
if err := script.Close(); err != nil {
|
||||
return fmt.Errorf("failed to save boot service file: %w", err)
|
||||
}
|
||||
|
||||
// Copy the boot script to container and start.
|
||||
cmd := exec.Command("podman", "cp", "--pause=false", script.Name(), "unifi-os:/lib/systemd/system/ctrld-boot.service")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to copy boot script, out: %s, err: %v", string(out), err)
|
||||
}
|
||||
cmd = exec.Command("podman", "exec", "unifi-os", "systemctl", "enable", "--now", "ctrld-boot.service")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to start ctrld boot script, out: %s, err: %v", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) Uninstall() error {
|
||||
if err := os.Remove(s.configPath()); err != nil {
|
||||
return err
|
||||
}
|
||||
// Remove ctrld-boot service inside unifi-os container.
|
||||
cmd := exec.Command("podman", "exec", "unifi-os", "systemctl", "disable", "ctrld-boot.service")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to disable ctrld-boot service, out: %s, err: %v", string(out), err)
|
||||
}
|
||||
cmd = exec.Command("podman", "exec", "unifi-os", "rm", "/lib/systemd/system/ctrld-boot.service")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to remove ctrld-boot service file, out: %s, err: %v", string(out), err)
|
||||
}
|
||||
cmd = exec.Command("podman", "exec", "unifi-os", "systemctl", "daemon-reload")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to reload systemd service, out: %s, err: %v", string(out), err)
|
||||
}
|
||||
cmd = exec.Command("podman", "exec", "unifi-os", "systemctl", "reset-failed")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to reset-failed systemd service, out: %s, err: %v", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) Logger(errs chan<- error) (service.Logger, error) {
|
||||
if service.Interactive() {
|
||||
return service.ConsoleLogger, nil
|
||||
}
|
||||
return s.SystemLogger(errs)
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) SystemLogger(errs chan<- error) (service.Logger, error) {
|
||||
return newSysLogger(s.Name, errs)
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) Run() (err error) {
|
||||
err = s.i.Start(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if interactice, _ := isInteractive(); !interactice {
|
||||
signal.Ignore(syscall.SIGHUP)
|
||||
}
|
||||
|
||||
var sigChan = make(chan os.Signal, 3)
|
||||
signal.Notify(sigChan, syscall.SIGTERM, os.Interrupt)
|
||||
<-sigChan
|
||||
|
||||
return s.i.Stop(s)
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) Status() (service.Status, error) {
|
||||
if _, err := os.Stat(s.configPath()); os.IsNotExist(err) {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
out, err := exec.Command(s.configPath(), "status").CombinedOutput()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, err
|
||||
}
|
||||
switch string(bytes.TrimSpace(out)) {
|
||||
case "Running":
|
||||
return service.StatusRunning, nil
|
||||
default:
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) Start() error {
|
||||
return exec.Command(s.configPath(), "start").Run()
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) Stop() error {
|
||||
return exec.Command(s.configPath(), "stop").Run()
|
||||
}
|
||||
|
||||
func (s *ubiosSvc) Restart() error {
|
||||
err := s.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return s.Start()
|
||||
}
|
||||
|
||||
const ubiosBootSystemdService = `[Unit]
|
||||
Description=Run ctrld On Startup UDM
|
||||
Wants=network-online.target
|
||||
After=network-online.target
|
||||
StartLimitIntervalSec=500
|
||||
StartLimitBurst=5
|
||||
|
||||
[Service]
|
||||
Restart=on-failure
|
||||
RestartSec=5s
|
||||
ExecStart=/sbin/ssh-proxy '[ -f "{{.DnsMasqConfPath}}" ] || {{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}'
|
||||
RemainAfterExit=true
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
`
|
||||
|
||||
const ubiosSvcScript = `#!/bin/sh
|
||||
# For RedHat and cousins:
|
||||
# chkconfig: - 99 01
|
||||
# description: {{.Description}}
|
||||
# processname: {{.Path}}
|
||||
|
||||
### BEGIN INIT INFO
|
||||
# Provides: {{.Path}}
|
||||
# Required-Start:
|
||||
# Required-Stop:
|
||||
# Default-Start: 2 3 4 5
|
||||
# Default-Stop: 0 1 6
|
||||
# Short-Description: {{.DisplayName}}
|
||||
# Description: {{.Description}}
|
||||
### END INIT INFO
|
||||
|
||||
cmd="{{.Path}}{{range .Arguments}} {{.|cmd}}{{end}}"
|
||||
|
||||
name=$(basename $(readlink -f $0))
|
||||
pid_file="/var/run/$name.pid"
|
||||
stdout_log="/var/log/$name.log"
|
||||
stderr_log="/var/log/$name.err"
|
||||
|
||||
[ -e /etc/sysconfig/$name ] && . /etc/sysconfig/$name
|
||||
|
||||
get_pid() {
|
||||
cat "$pid_file"
|
||||
}
|
||||
|
||||
is_running() {
|
||||
[ -f "$pid_file" ] && cat /proc/$(get_pid)/stat > /dev/null 2>&1
|
||||
}
|
||||
|
||||
case "$1" in
|
||||
start)
|
||||
if is_running; then
|
||||
echo "Already started"
|
||||
else
|
||||
echo "Starting $name"
|
||||
{{if .WorkingDirectory}}cd '{{.WorkingDirectory}}'{{end}}
|
||||
$cmd >> "$stdout_log" 2>> "$stderr_log" &
|
||||
echo $! > "$pid_file"
|
||||
if ! is_running; then
|
||||
echo "Unable to start, see $stdout_log and $stderr_log"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
stop)
|
||||
if is_running; then
|
||||
echo -n "Stopping $name.."
|
||||
kill $(get_pid)
|
||||
for i in $(seq 1 10)
|
||||
do
|
||||
if ! is_running; then
|
||||
break
|
||||
fi
|
||||
echo -n "."
|
||||
sleep 1
|
||||
done
|
||||
echo
|
||||
if is_running; then
|
||||
echo "Not stopped; may still be shutting down or shutdown may have failed"
|
||||
exit 1
|
||||
else
|
||||
echo "Stopped"
|
||||
if [ -f "$pid_file" ]; then
|
||||
rm "$pid_file"
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "Not running"
|
||||
fi
|
||||
;;
|
||||
restart)
|
||||
$0 stop
|
||||
if is_running; then
|
||||
echo "Unable to stop, will not attempt to start"
|
||||
exit 1
|
||||
fi
|
||||
$0 start
|
||||
;;
|
||||
status)
|
||||
if is_running; then
|
||||
echo "Running"
|
||||
else
|
||||
echo "Stopped"
|
||||
exit 1
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {start|stop|restart|status}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
exit 0
|
||||
`
|
||||
|
||||
var tf = map[string]interface{}{
|
||||
"cmd": func(s string) string {
|
||||
return `"` + strings.Replace(s, `"`, `\"`, -1) + `"`
|
||||
},
|
||||
"cmdEscape": func(s string) string {
|
||||
return strings.Replace(s, " ", `\x20`, -1)
|
||||
},
|
||||
}
|
||||
55
internal/router/synology.go
Normal file
55
internal/router/synology.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
const (
|
||||
synologyDNSMasqConfigPath = "/etc/dhcpd/dhcpd-zzz-ctrld.conf"
|
||||
synologyDhcpdInfoPath = "/etc/dhcpd/dhcpd-zzz-ctrld.info"
|
||||
)
|
||||
|
||||
func setupSynology() error {
|
||||
dnsMasqConfigContent, err := dnsMasqConf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(synologyDNSMasqConfigPath, []byte(dnsMasqConfigContent), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(synologyDhcpdInfoPath, []byte(`enable="yes"`), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupSynology() error {
|
||||
// Remove the custom config files.
|
||||
for _, f := range []string{synologyDNSMasqConfigPath, synologyDhcpdInfoPath} {
|
||||
if err := os.Remove(f); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postInstallSynology() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func synologyRestartDNSMasq() error {
|
||||
if out, err := exec.Command("/etc/rc.network", "nat-restart-dhcp").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("synologyRestartDNSMasq: %s - %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
49
internal/router/syslog.go
Normal file
49
internal/router/syslog.go
Normal file
@@ -0,0 +1,49 @@
|
||||
//go:build linux || darwin || freebsd
|
||||
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/syslog"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func newSysLogger(name string, errs chan<- error) (service.Logger, error) {
|
||||
w, err := syslog.New(syslog.LOG_INFO, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sysLogger{w, errs}, nil
|
||||
}
|
||||
|
||||
type sysLogger struct {
|
||||
*syslog.Writer
|
||||
errs chan<- error
|
||||
}
|
||||
|
||||
func (s sysLogger) send(err error) error {
|
||||
if err != nil && s.errs != nil {
|
||||
s.errs <- err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s sysLogger) Error(v ...interface{}) error {
|
||||
return s.send(s.Writer.Err(fmt.Sprint(v...)))
|
||||
}
|
||||
func (s sysLogger) Warning(v ...interface{}) error {
|
||||
return s.send(s.Writer.Warning(fmt.Sprint(v...)))
|
||||
}
|
||||
func (s sysLogger) Info(v ...interface{}) error {
|
||||
return s.send(s.Writer.Info(fmt.Sprint(v...)))
|
||||
}
|
||||
func (s sysLogger) Errorf(format string, a ...interface{}) error {
|
||||
return s.send(s.Writer.Err(fmt.Sprintf(format, a...)))
|
||||
}
|
||||
func (s sysLogger) Warningf(format string, a ...interface{}) error {
|
||||
return s.send(s.Writer.Warning(fmt.Sprintf(format, a...)))
|
||||
}
|
||||
func (s sysLogger) Infof(format string, a ...interface{}) error {
|
||||
return s.send(s.Writer.Info(fmt.Sprintf(format, a...)))
|
||||
}
|
||||
7
internal/router/syslog_windows.go
Normal file
7
internal/router/syslog_windows.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package router
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
func newSysLogger(name string, errs chan<- error) (service.Logger, error) {
|
||||
return service.ConsoleLogger, nil
|
||||
}
|
||||
82
internal/router/tomato.go
Normal file
82
internal/router/tomato.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
const (
|
||||
tomatoDnsCryptProxySvcName = "dnscrypt-proxy"
|
||||
tomatoStubbySvcName = "stubby"
|
||||
tomatoDNSMasqSvcName = "dnsmasq"
|
||||
)
|
||||
|
||||
func setupTomato() error {
|
||||
// Already setup.
|
||||
if val, _ := nvram("get", nvramCtrldSetupKey); val == "1" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := dnsMasqConf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nvramKvMap := nvramSetupKV()
|
||||
nvramKvMap["dnsmasq_custom"] = data
|
||||
if err := nvramSetKV(nvramKvMap, nvramCtrldSetupKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Restart dnscrypt-proxy service.
|
||||
if err := tomatoRestartServiceWithKill(tomatoDnsCryptProxySvcName, true); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart stubby service.
|
||||
if err := tomatoRestartService(tomatoStubbySvcName); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postInstallTomato() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupTomato() error {
|
||||
// Restore old configs.
|
||||
if err := nvramRestore(nvramSetupKV(), nvramCtrldSetupKey); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnscrypt-proxy service.
|
||||
if err := tomatoRestartServiceWithKill(tomatoDnsCryptProxySvcName, true); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart stubby service.
|
||||
if err := tomatoRestartService(tomatoStubbySvcName); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func tomatoRestartService(name string) error {
|
||||
return tomatoRestartServiceWithKill(name, false)
|
||||
}
|
||||
|
||||
func tomatoRestartServiceWithKill(name string, killBeforeRestart bool) error {
|
||||
if killBeforeRestart {
|
||||
_, _ = exec.Command("killall", name).CombinedOutput()
|
||||
}
|
||||
if out, err := exec.Command("service", name, "restart").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("service restart %s: %s, %w", name, string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
73
internal/router/ubios.go
Normal file
73
internal/router/ubios.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var errContentFilteringEnabled = fmt.Errorf(`the "Content Filtering" feature" is enabled, which is conflicted with ctrld.\n
|
||||
To disable it, folowing instruction here: %s`, toggleContentFilteringLink)
|
||||
|
||||
const (
|
||||
ubiosDNSMasqConfigPath = "/run/dnsmasq.conf.d/zzzctrld.conf"
|
||||
toggleContentFilteringLink = "https://community.ui.com/questions/UDM-Pro-disable-enable-DNS-filtering/e2cc4060-e56a-4139-b200-62d7f773ff8f"
|
||||
)
|
||||
|
||||
func setupUbiOS() error {
|
||||
// Disable dnsmasq as DNS server.
|
||||
dnsMasqConfigContent, err := dnsMasqConf()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(ubiosDNSMasqConfigPath, []byte(dnsMasqConfigContent), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cleanupUbiOS() error {
|
||||
// Remove the custom dnsmasq config
|
||||
if err := os.Remove(ubiosDNSMasqConfigPath); err != nil {
|
||||
return err
|
||||
}
|
||||
// Restart dnsmasq service.
|
||||
if err := restartDNSMasq(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func postInstallUbiOS() error {
|
||||
// See comment in postInstallEdgeOS.
|
||||
if contentFilteringEnabled() {
|
||||
return errContentFilteringEnabled
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ubiosRestartDNSMasq() error {
|
||||
buf, err := os.ReadFile("/run/dnsmasq.pid")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pid, err := strconv.ParseUint(string(bytes.TrimSpace(buf)), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
proc, err := os.FindProcess(int(pid))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return proc.Kill()
|
||||
}
|
||||
|
||||
func contentFilteringEnabled() bool {
|
||||
st, err := os.Stat("/run/dnsfilter/dnsfilter")
|
||||
return err == nil && !st.IsDir()
|
||||
}
|
||||
29
nameservers.go
Normal file
29
nameservers.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package ctrld
|
||||
|
||||
import "net"
|
||||
|
||||
type dnsFn func() []string
|
||||
|
||||
func nameservers() []string {
|
||||
var dns []string
|
||||
seen := make(map[string]bool)
|
||||
ch := make(chan []string)
|
||||
fns := dnsFns()
|
||||
|
||||
for _, fn := range fns {
|
||||
go func(fn dnsFn) {
|
||||
ch <- fn()
|
||||
}(fn)
|
||||
}
|
||||
for range fns {
|
||||
for _, ns := range <-ch {
|
||||
if seen[ns] {
|
||||
continue
|
||||
}
|
||||
seen[ns] = true
|
||||
dns = append(dns, net.JoinHostPort(ns, "53"))
|
||||
}
|
||||
}
|
||||
|
||||
return dns
|
||||
}
|
||||
75
nameservers_bsd.go
Normal file
75
nameservers_bsd.go
Normal file
@@ -0,0 +1,75 @@
|
||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/net/route"
|
||||
)
|
||||
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromRIB, dnsFromIPConfig}
|
||||
}
|
||||
|
||||
func dnsFromRIB() []string {
|
||||
var dns []string
|
||||
rib, err := route.FetchRIB(syscall.AF_UNSPEC, route.RIBTypeRoute, 0)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
messages, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for _, message := range messages {
|
||||
message, ok := message.(*route.RouteMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
addresses := message.Addrs
|
||||
if len(addresses) < 2 {
|
||||
continue
|
||||
}
|
||||
dst, gw := toNetIP(addresses[0]), toNetIP(addresses[1])
|
||||
if dst == nil || gw == nil {
|
||||
continue
|
||||
}
|
||||
if gw.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
if dst.Equal(net.IPv4zero) || dst.Equal(net.IPv6zero) {
|
||||
dns = append(dns, gw.String())
|
||||
}
|
||||
}
|
||||
return dns
|
||||
}
|
||||
|
||||
func dnsFromIPConfig() []string {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return nil
|
||||
}
|
||||
cmd := exec.Command("ipconfig", "getoption", "", "domain_name_server")
|
||||
out, _ := cmd.Output()
|
||||
if ip := net.ParseIP(strings.TrimSpace(string(out))); ip != nil {
|
||||
return []string{ip.String()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toNetIP(addr route.Addr) net.IP {
|
||||
switch t := addr.(type) {
|
||||
case *route.Inet4Addr:
|
||||
return net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3])
|
||||
case *route.Inet6Addr:
|
||||
ip := make(net.IP, net.IPv6len)
|
||||
copy(ip, t.IP[:])
|
||||
return ip
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
97
nameservers_linux.go
Normal file
97
nameservers_linux.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
const (
|
||||
v4RouteFile = "/proc/net/route"
|
||||
v6RouteFile = "/proc/net/ipv6_route"
|
||||
)
|
||||
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dns4, dns6, dnsFromSystemdResolver}
|
||||
}
|
||||
|
||||
func dns4() []string {
|
||||
f, err := os.Open(v4RouteFile)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var dns []string
|
||||
seen := make(map[string]bool)
|
||||
s := bufio.NewScanner(f)
|
||||
first := true
|
||||
for s.Scan() {
|
||||
if first {
|
||||
first = false
|
||||
continue
|
||||
}
|
||||
fields := bytes.Fields(s.Bytes())
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
gw := make([]byte, net.IPv4len)
|
||||
// Third fields is gateway.
|
||||
if _, err := hex.Decode(gw, fields[2]); err != nil {
|
||||
continue
|
||||
}
|
||||
ip := net.IPv4(gw[3], gw[2], gw[1], gw[0])
|
||||
if ip.Equal(net.IPv4zero) || seen[ip.String()] {
|
||||
continue
|
||||
}
|
||||
seen[ip.String()] = true
|
||||
dns = append(dns, ip.String())
|
||||
}
|
||||
return dns
|
||||
}
|
||||
|
||||
func dns6() []string {
|
||||
f, err := os.Open(v6RouteFile)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var dns []string
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
fields := bytes.Fields(s.Bytes())
|
||||
if len(fields) < 4 {
|
||||
continue
|
||||
}
|
||||
|
||||
gw := make([]byte, net.IPv6len)
|
||||
// Fifth fields is gateway.
|
||||
if _, err := hex.Decode(gw, fields[4]); err != nil {
|
||||
continue
|
||||
}
|
||||
ip := net.IP(gw)
|
||||
if ip.Equal(net.IPv6zero) {
|
||||
continue
|
||||
}
|
||||
dns = append(dns, ip.String())
|
||||
}
|
||||
return dns
|
||||
}
|
||||
|
||||
func dnsFromSystemdResolver() []string {
|
||||
c, err := resolvconffile.ParseFile("/run/systemd/resolve/resolv.conf")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
ns := make([]string, 0, len(c.Nameservers))
|
||||
for _, nameserver := range c.Nameservers {
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
return ns
|
||||
}
|
||||
11
nameservers_test.go
Normal file
11
nameservers_test.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package ctrld
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNameservers(t *testing.T) {
|
||||
ns := nameservers()
|
||||
if len(ns) == 0 {
|
||||
t.Fatal("failed to get nameservers")
|
||||
}
|
||||
t.Log(ns)
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build !js && !windows
|
||||
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
func nameservers() []string {
|
||||
return resolvconffile.NameServersWithPort()
|
||||
}
|
||||
@@ -2,70 +2,59 @@ package ctrld
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func nameservers() []string {
|
||||
aas, err := adapterAddresses()
|
||||
func dnsFns() []dnsFn {
|
||||
return []dnsFn{dnsFromAdapter}
|
||||
}
|
||||
|
||||
func dnsFromAdapter() []string {
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, winipcfg.GAAFlagIncludeGateways|winipcfg.GAAFlagIncludePrefix)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
ns := make([]string, 0, len(aas))
|
||||
ns := make([]string, 0, len(aas)*2)
|
||||
seen := make(map[string]bool)
|
||||
do := func(addr windows.SocketAddress) {
|
||||
sa, err := addr.Sockaddr.Sockaddr()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var ip net.IP
|
||||
switch sa := sa.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
ip = net.IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])
|
||||
case *syscall.SockaddrInet6:
|
||||
ip = make(net.IP, net.IPv6len)
|
||||
copy(ip, sa.Addr[:])
|
||||
if ip[0] == 0xfe && ip[1] == 0xc0 {
|
||||
// Ignore these fec0/10 ones. Windows seems to
|
||||
// populate them as defaults on its misc rando
|
||||
// interfaces.
|
||||
return
|
||||
}
|
||||
default:
|
||||
return
|
||||
|
||||
}
|
||||
if ip.IsLoopback() || seen[ip.String()] {
|
||||
return
|
||||
}
|
||||
seen[ip.String()] = true
|
||||
ns = append(ns, ip.String())
|
||||
}
|
||||
for _, aa := range aas {
|
||||
for dns := aa.FirstDnsServerAddress; dns != nil; dns = dns.Next {
|
||||
sa, err := dns.Address.Sockaddr.Sockaddr()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var ip net.IP
|
||||
switch sa := sa.(type) {
|
||||
case *syscall.SockaddrInet4:
|
||||
ip = net.IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])
|
||||
case *syscall.SockaddrInet6:
|
||||
ip = make(net.IP, net.IPv6len)
|
||||
copy(ip, sa.Addr[:])
|
||||
if ip[0] == 0xfe && ip[1] == 0xc0 {
|
||||
// Ignore these fec0/10 ones. Windows seems to
|
||||
// populate them as defaults on its misc rando
|
||||
// interfaces.
|
||||
continue
|
||||
}
|
||||
default:
|
||||
// Unexpected type.
|
||||
continue
|
||||
}
|
||||
ns = append(ns, net.JoinHostPort(ip.String(), "53"))
|
||||
for dns := aa.FirstDNSServerAddress; dns != nil; dns = dns.Next {
|
||||
do(dns.Address)
|
||||
}
|
||||
for gw := aa.FirstGatewayAddress; gw != nil; gw = gw.Next {
|
||||
do(gw.Address)
|
||||
}
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
func adapterAddresses() ([]*windows.IpAdapterAddresses, error) {
|
||||
var b []byte
|
||||
l := uint32(15000) // recommended initial size
|
||||
for {
|
||||
b = make([]byte, l)
|
||||
err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, windows.GAA_FLAG_INCLUDE_PREFIX, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l)
|
||||
if err == nil {
|
||||
if l == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
break
|
||||
}
|
||||
if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW {
|
||||
return nil, os.NewSyscallError("getadaptersaddresses", err)
|
||||
}
|
||||
if l <= uint32(len(b)) {
|
||||
return nil, os.NewSyscallError("getadaptersaddresses", err)
|
||||
}
|
||||
}
|
||||
var aas []*windows.IpAdapterAddresses
|
||||
for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next {
|
||||
aas = append(aas, aa)
|
||||
}
|
||||
return aas, nil
|
||||
}
|
||||
|
||||
46
net.go
Normal file
46
net.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"tailscale.com/logtail/backoff"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
var (
|
||||
hasIPv6Once sync.Once
|
||||
ipv6Available atomic.Bool
|
||||
)
|
||||
|
||||
func hasIPv6() bool {
|
||||
hasIPv6Once.Do(func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
val := ctrldnet.IPv6Available(ctx)
|
||||
ipv6Available.Store(val)
|
||||
go probingIPv6(val)
|
||||
})
|
||||
return ipv6Available.Load()
|
||||
}
|
||||
|
||||
// TODO(cuonglm): doing poll check natively for supported platforms.
|
||||
func probingIPv6(old bool) {
|
||||
b := backoff.NewBackoff("probingIPv6", func(format string, args ...any) {}, 30*time.Second)
|
||||
bCtx := context.Background()
|
||||
for {
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
cur := ctrldnet.IPv6Available(ctx)
|
||||
if ipv6Available.CompareAndSwap(old, cur) {
|
||||
old = cur
|
||||
}
|
||||
}()
|
||||
b.BackOff(bCtx, errors.New("no change"))
|
||||
}
|
||||
}
|
||||
131
resolver.go
131
resolver.go
@@ -5,16 +5,24 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const (
|
||||
ResolverTypeDOH = "doh"
|
||||
ResolverTypeDOH3 = "doh3"
|
||||
ResolverTypeDOT = "dot"
|
||||
ResolverTypeDOQ = "doq"
|
||||
ResolverTypeOS = "os"
|
||||
// ResolverTypeDOH specifies DoH resolver.
|
||||
ResolverTypeDOH = "doh"
|
||||
// ResolverTypeDOH3 specifies DoH3 resolver.
|
||||
ResolverTypeDOH3 = "doh3"
|
||||
// ResolverTypeDOT specifies DoT resolver.
|
||||
ResolverTypeDOT = "dot"
|
||||
// ResolverTypeDOQ specifies DoQ resolver.
|
||||
ResolverTypeDOQ = "doq"
|
||||
// ResolverTypeOS specifies OS resolver.
|
||||
ResolverTypeOS = "os"
|
||||
// ResolverTypeLegacy specifies legacy resolver.
|
||||
ResolverTypeLegacy = "legacy"
|
||||
)
|
||||
|
||||
@@ -32,7 +40,7 @@ var errUnknownResolver = errors.New("unknown resolver")
|
||||
|
||||
// NewResolver creates a Resolver based on the given upstream config.
|
||||
func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
typ, endpoint := uc.Type, uc.Endpoint
|
||||
typ := uc.Type
|
||||
switch typ {
|
||||
case ResolverTypeDOH, ResolverTypeDOH3:
|
||||
return newDohResolver(uc), nil
|
||||
@@ -43,7 +51,7 @@ func NewResolver(uc *UpstreamConfig) (Resolver, error) {
|
||||
case ResolverTypeOS:
|
||||
return or, nil
|
||||
case ResolverTypeLegacy:
|
||||
return &legacyResolver{endpoint: endpoint}, nil
|
||||
return &legacyResolver{uc: uc}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", errUnknownResolver, typ)
|
||||
}
|
||||
@@ -69,9 +77,16 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
|
||||
dnsClient := &dns.Client{Net: "udp"}
|
||||
ch := make(chan *osResolverResult, numServers)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(o.nameservers))
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(ch)
|
||||
}()
|
||||
for _, server := range o.nameservers {
|
||||
go func(server string) {
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg, server)
|
||||
defer wg.Done()
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg.Copy(), server)
|
||||
ch <- &osResolverResult{answer: answer, err: err}
|
||||
}(server)
|
||||
}
|
||||
@@ -85,7 +100,7 @@ func (o *osResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error
|
||||
errs = append(errs, res.err)
|
||||
}
|
||||
|
||||
return nil, joinErrors(errs...)
|
||||
return nil, errors.Join(errs...)
|
||||
}
|
||||
|
||||
func newDialer(dnsAddress string) *net.Dialer {
|
||||
@@ -101,16 +116,108 @@ func newDialer(dnsAddress string) *net.Dialer {
|
||||
}
|
||||
|
||||
type legacyResolver struct {
|
||||
endpoint string
|
||||
uc *UpstreamConfig
|
||||
}
|
||||
|
||||
func (r *legacyResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
|
||||
// See comment in (*dotResolver).resolve method.
|
||||
dialer := newDialer(net.JoinHostPort(bootstrapDNS, "53"))
|
||||
dnsTyp := uint16(0)
|
||||
if msg != nil && len(msg.Question) > 0 {
|
||||
dnsTyp = msg.Question[0].Qtype
|
||||
}
|
||||
_, udpNet := r.uc.netForDNSType(dnsTyp)
|
||||
dnsClient := &dns.Client{
|
||||
Net: "udp",
|
||||
Net: udpNet,
|
||||
Dialer: dialer,
|
||||
}
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg, r.endpoint)
|
||||
endpoint := r.uc.Endpoint
|
||||
if r.uc.BootstrapIP != "" {
|
||||
dnsClient.Net = "udp"
|
||||
_, port, _ := net.SplitHostPort(endpoint)
|
||||
endpoint = net.JoinHostPort(r.uc.BootstrapIP, port)
|
||||
}
|
||||
|
||||
answer, _, err := dnsClient.ExchangeContext(ctx, msg, endpoint)
|
||||
return answer, err
|
||||
}
|
||||
|
||||
// LookupIP looks up host using OS resolver.
|
||||
// It returns a slice of that host's IPv4 and IPv6 addresses.
|
||||
func LookupIP(domain string) []string {
|
||||
return lookupIP(domain, -1, true)
|
||||
}
|
||||
|
||||
func lookupIP(domain string, timeout int, withBootstrapDNS bool) (ips []string) {
|
||||
resolver := &osResolver{nameservers: nameservers()}
|
||||
if withBootstrapDNS {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
||||
}
|
||||
ProxyLog.Debug().Msgf("Resolving %q using bootstrap DNS %q", domain, resolver.nameservers)
|
||||
timeoutMs := 2000
|
||||
if timeout > 0 && timeout < timeoutMs {
|
||||
timeoutMs = timeout
|
||||
}
|
||||
questionDomain := dns.Fqdn(domain)
|
||||
ipFromRecord := func(record dns.RR) string {
|
||||
switch ar := record.(type) {
|
||||
case *dns.A:
|
||||
if ar.Hdr.Name != questionDomain {
|
||||
return ""
|
||||
}
|
||||
return ar.A.String()
|
||||
case *dns.AAAA:
|
||||
if ar.Hdr.Name != questionDomain {
|
||||
return ""
|
||||
}
|
||||
return ar.AAAA.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
lookup := func(dnsType uint16) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(questionDomain, dnsType)
|
||||
m.RecursionDesired = true
|
||||
|
||||
r, err := resolver.Resolve(ctx, m)
|
||||
if err != nil {
|
||||
ProxyLog.Error().Err(err).Msgf("could not lookup %q record for domain %q", dns.TypeToString[dnsType], domain)
|
||||
return
|
||||
}
|
||||
if r.Rcode != dns.RcodeSuccess {
|
||||
ProxyLog.Error().Msgf("could not resolve domain %q, return code: %s", domain, dns.RcodeToString[r.Rcode])
|
||||
return
|
||||
}
|
||||
if len(r.Answer) == 0 {
|
||||
ProxyLog.Error().Msg("no answer from OS resolver")
|
||||
return
|
||||
}
|
||||
for _, a := range r.Answer {
|
||||
if ip := ipFromRecord(a); ip != "" {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Find all A, AAAA records of the domain.
|
||||
for _, dnsType := range []uint16{dns.TypeAAAA, dns.TypeA} {
|
||||
lookup(dnsType)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// NewBootstrapResolver returns an OS resolver, which use following nameservers:
|
||||
//
|
||||
// - ControlD bootstrap DNS server.
|
||||
// - Gateway IP address (depends on OS).
|
||||
// - Input servers.
|
||||
func NewBootstrapResolver(servers ...string) Resolver {
|
||||
resolver := &osResolver{nameservers: nameservers()}
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(bootstrapDNS, "53")}, resolver.nameservers...)
|
||||
for _, ns := range servers {
|
||||
resolver.nameservers = append([]string{net.JoinHostPort(ns, "53")}, resolver.nameservers...)
|
||||
}
|
||||
return resolver
|
||||
}
|
||||
|
||||
53
resolver_test.go
Normal file
53
resolver_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func Test_osResolver_Resolve(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
resolver := &osResolver{nameservers: []string{"127.0.0.127:5353"}}
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("controld.com.", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
_, _ = resolver.Resolve(context.Background(), m)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Error("os resolver hangs")
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
func Test_upstreamTypeFromEndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endpoint string
|
||||
resolverType string
|
||||
}{
|
||||
{"doh", "https://freedns.controld.com/p2", ResolverTypeDOH},
|
||||
{"doq", "quic://p2.freedns.controld.com", ResolverTypeDOQ},
|
||||
{"dot", "p2.freedns.controld.com", ResolverTypeDOT},
|
||||
{"legacy", "8.8.8.8:53", ResolverTypeLegacy},
|
||||
{"legacy ipv6", "[2404:6800:4005:809::200e]:53", ResolverTypeLegacy},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if rt := ResolverTypeFromEndpoint(tc.endpoint); rt != tc.resolverType {
|
||||
t.Errorf("mismatch, want: %s, got: %s", tc.resolverType, rt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -50,6 +50,13 @@ type = "legacy"
|
||||
endpoint = "8.8.8.8"
|
||||
timeout = 5
|
||||
|
||||
[upstream.3]
|
||||
name = "DOH with client info"
|
||||
type = "doh"
|
||||
endpoint = "https://dns.controld.com/client_info_upstream/main-device"
|
||||
timeout = 5
|
||||
send_client_info = false
|
||||
|
||||
[listener.0]
|
||||
ip = "127.0.0.1"
|
||||
port = 53
|
||||
|
||||
Reference in New Issue
Block a user