all: add starting service with Control D config

This commit is contained in:
Cuong Manh Le
2022-12-22 23:27:45 +07:00
committed by Cuong Manh Le
parent ec72af1916
commit 114ef9aad6
7 changed files with 216 additions and 23 deletions

View File

@@ -17,6 +17,7 @@ import (
"github.com/spf13/viper"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/controld"
)
var (
@@ -61,15 +62,15 @@ func initCLI() {
log.Fatal("Cannot run in daemon mode. Please install a Windows service.")
}
noConfigStart := isNoConfigStart(cmd)
noConfigStart := isNoConfigStart(cmd) && cdUID != ""
writeDefaultConfig := !noConfigStart && configBase64 == ""
configs := []struct {
name string
written bool
}{
// For compatibility, we check for config.toml first, but only read it if exists.
{"config", false},
{"ctrld", !noConfigStart && configBase64 == ""},
{"ctrld", writeDefaultConfig},
}
for _, config := range configs {
ctrld.SetConfigName(v, config.name)
@@ -81,6 +82,7 @@ func initCLI() {
readBase64Config()
processNoConfigFlags(noConfigStart)
processCDFlags()
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
}
@@ -138,6 +140,7 @@ func initCLI() {
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy")
runCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file")
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
rootCmd.AddCommand(runCmd)
@@ -183,6 +186,7 @@ func initCLI() {
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy")
startCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file")
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
stopCmd := &cobra.Command{
Use: "stop",
@@ -308,21 +312,24 @@ func writeConfigFile() {
}
}
func readConfigFile(configWritten bool) bool {
func readConfigFile(writeDefaultConfig bool) bool {
// If err == nil, there's a config supplied via `--config`, no default config written.
err := v.ReadInConfig()
if err == nil {
return true
}
if !configWritten {
if !writeDefaultConfig {
return false
}
// If error is viper.ConfigFileNotFoundError, write default config.
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
writeConfigFile()
defaultConfigWritten = true
return false
}
// Otherwise, report fatal error and exit.
log.Fatalf("failed to decode config file: %v", err)
return false
}
@@ -347,21 +354,7 @@ func processNoConfigFlags(noConfigStart bool) {
if listenAddress == "" || primaryUpstream == "" {
log.Fatal(`"listen" and "primary_upstream" flags must be set in no config mode`)
}
host, portStr, err := net.SplitHostPort(listenAddress)
if err != nil {
log.Fatalf("invalid listener address: %v", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Fatalf("invalid port number: %v", err)
}
lc := &ctrld.ListenerConfig{
IP: host,
Port: port,
}
v.Set("listener", map[string]*ctrld.ListenerConfig{
"0": lc,
})
processListenFlag()
upstream := map[string]*ctrld.UpstreamConfig{
"0": {
@@ -380,10 +373,67 @@ func processNoConfigFlags(noConfigStart bool) {
for _, domain := range domains {
rules = append(rules, ctrld.Rule{domain: []string{"upstream.1"}})
}
lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"]
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
}
v.Set("upstream", upstream)
processLogAndCacheFlags()
}
func processCDFlags() {
if cdUID == "" {
return
}
resolverConfig, err := controld.FetchResolverConfig(cdUID)
if err != nil {
log.Fatalf("failed to fetch resolver config: %v", err)
}
upstream := map[string]*ctrld.UpstreamConfig{
"0": {
BootstrapIP: resolverConfig.IP(supportsIPv6()),
Name: resolverConfig.DOH,
Endpoint: resolverConfig.DOH,
Type: ctrld.ResolverTypeDOH,
},
}
v.Set("upstream", upstream)
processListenFlag()
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
for _, domain := range resolverConfig.Exclude {
rules = append(rules, ctrld.Rule{domain: []string{}})
}
lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"]
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
processLogAndCacheFlags()
}
func processListenFlag() {
if listenAddress == "" {
return
}
host, portStr, err := net.SplitHostPort(listenAddress)
if err != nil {
log.Fatalf("invalid listener address: %v", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Fatalf("invalid port number: %v", err)
}
lc := &ctrld.ListenerConfig{
IP: host,
Port: port,
}
v.Set("listener", map[string]*ctrld.ListenerConfig{
"0": lc,
})
}
func processLogAndCacheFlags() {
sc := ctrld.ServiceConfig{}
if logPath != "" {
sc.LogLevel = "debug"

View File

@@ -143,6 +143,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
}
}
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
if len(upstreamConfigs) == 0 {
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
upstreams = []string{"upstream.os"}
}
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
@@ -204,9 +208,6 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U
upstreamNum := strings.TrimPrefix(upstream, "upstream.")
upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
}
if len(upstreamConfigs) == 0 {
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
}
return upstreamConfigs
}

View File

@@ -31,6 +31,8 @@ var (
rootLogger = zerolog.New(io.Discard)
mainLog = rootLogger
proxyLog = rootLogger
cdUID string
)
func main() {

View File

@@ -25,6 +25,8 @@ Usage:
Flags:
--base64_config string base64 encoded config
--cache_size int Enable cache with size items
--cd string Control D resolver uid
-c, --config string Path to config file
-d, --daemon Run as daemon
--domains strings list of domain to apply in a split DNS policy

15
docs/controld_config.md Normal file
View File

@@ -0,0 +1,15 @@
# Control D config
`ctrld` can build a Control D config and run with the specific resolver data.
For example:
```shell
ctrld run --cd p2
```
Above command will fetch the `p2` resolver data from Control D API and use that data for running `ctrld`:
- The resolver `doh` endpoint will be used as the primary upstream.
- The resolver `exclude` list will be used to create a rule policy which will steer them to the default OS resolver.
```

View File

@@ -0,0 +1,90 @@
package controld
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"time"
)
const resolverDataURL = "https://api.controld.com/utility"
// ResolverConfig represents Control D resolver data.
type ResolverConfig struct {
V4 []string `json:"v4"`
V6 []string `json:"v6"`
DOH string `json:"doh"`
Exclude []string `json:"exclude"`
}
func (r *ResolverConfig) IP(v6 bool) string {
ip4 := r.v4()
ip6 := r.v6()
if v6 && ip6 != "" {
return ip6
}
return ip4
}
func (r *ResolverConfig) v4() string {
for _, ip := range r.V4 {
return ip
}
return ""
}
func (r *ResolverConfig) v6() string {
for _, ip := range r.V6 {
return ip
}
return ""
}
type utilityResponse struct {
Success bool `json:"success"`
Body struct {
Resolver ResolverConfig `json:"resolver"`
} `json:"body"`
}
type utilityErrorResponse struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
type utilityRequest struct {
UID string `json:"uid"`
}
// FetchResolverConfig fetch Control D config for given uid.
func FetchResolverConfig(uid string) (*ResolverConfig, error) {
body, _ := json.Marshal(utilityRequest{UID: uid})
req, err := http.NewRequest("POST", resolverDataURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("http.NewRequest: %w", err)
}
req.Header.Add("Content-Type", "application/json")
client := http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("client.Do: %w", err)
}
defer resp.Body.Close()
d := json.NewDecoder(resp.Body)
if resp.StatusCode != http.StatusOK {
errResp := &utilityErrorResponse{}
if err := d.Decode(errResp); err != nil {
return nil, err
}
return nil, errors.New(errResp.Error.Message)
}
ur := &utilityResponse{}
if err := d.Decode(ur); err != nil {
return nil, err
}
return &ur.Body.Resolver, nil
}

View File

@@ -0,0 +1,33 @@
//go:build controld
package controld
import (
"testing"
"github.com/stretchr/testify/assert"
)
const utilityURL = "https://api.controld.com/utility"
func TestFetchResolverConfig(t *testing.T) {
tests := []struct {
name string
uid string
wantErr bool
}{
{"valid", "p2", false},
{"invalid uid", "abcd1234", true},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got, err := FetchResolverConfig(tc.uid)
assert.False(t, (err != nil) != tc.wantErr)
if !tc.wantErr {
assert.NotEmpty(t, got.DOH)
}
})
}
}