mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
all: add starting service with Control D config
This commit is contained in:
committed by
Cuong Manh Le
parent
ec72af1916
commit
114ef9aad6
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -61,15 +62,15 @@ func initCLI() {
|
||||
log.Fatal("Cannot run in daemon mode. Please install a Windows service.")
|
||||
}
|
||||
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
|
||||
noConfigStart := isNoConfigStart(cmd) && cdUID != ""
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
configs := []struct {
|
||||
name string
|
||||
written bool
|
||||
}{
|
||||
// For compatibility, we check for config.toml first, but only read it if exists.
|
||||
{"config", false},
|
||||
{"ctrld", !noConfigStart && configBase64 == ""},
|
||||
{"ctrld", writeDefaultConfig},
|
||||
}
|
||||
for _, config := range configs {
|
||||
ctrld.SetConfigName(v, config.name)
|
||||
@@ -81,6 +82,7 @@ func initCLI() {
|
||||
|
||||
readBase64Config()
|
||||
processNoConfigFlags(noConfigStart)
|
||||
processCDFlags()
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
log.Fatalf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
@@ -138,6 +140,7 @@ func initCLI() {
|
||||
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy")
|
||||
runCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file")
|
||||
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
@@ -183,6 +186,7 @@ func initCLI() {
|
||||
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy")
|
||||
startCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file")
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop",
|
||||
@@ -308,21 +312,24 @@ func writeConfigFile() {
|
||||
}
|
||||
}
|
||||
|
||||
func readConfigFile(configWritten bool) bool {
|
||||
func readConfigFile(writeDefaultConfig bool) bool {
|
||||
// If err == nil, there's a config supplied via `--config`, no default config written.
|
||||
err := v.ReadInConfig()
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if !configWritten {
|
||||
if !writeDefaultConfig {
|
||||
return false
|
||||
}
|
||||
|
||||
// If error is viper.ConfigFileNotFoundError, write default config.
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
writeConfigFile()
|
||||
defaultConfigWritten = true
|
||||
return false
|
||||
}
|
||||
// Otherwise, report fatal error and exit.
|
||||
log.Fatalf("failed to decode config file: %v", err)
|
||||
return false
|
||||
}
|
||||
@@ -347,21 +354,7 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
if listenAddress == "" || primaryUpstream == "" {
|
||||
log.Fatal(`"listen" and "primary_upstream" flags must be set in no config mode`)
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(listenAddress)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid listener address: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid port number: %v", err)
|
||||
}
|
||||
lc := &ctrld.ListenerConfig{
|
||||
IP: host,
|
||||
Port: port,
|
||||
}
|
||||
v.Set("listener", map[string]*ctrld.ListenerConfig{
|
||||
"0": lc,
|
||||
})
|
||||
processListenFlag()
|
||||
|
||||
upstream := map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
@@ -380,10 +373,67 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
for _, domain := range domains {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{"upstream.1"}})
|
||||
}
|
||||
lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"]
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
|
||||
}
|
||||
v.Set("upstream", upstream)
|
||||
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
func processCDFlags() {
|
||||
if cdUID == "" {
|
||||
return
|
||||
}
|
||||
resolverConfig, err := controld.FetchResolverConfig(cdUID)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to fetch resolver config: %v", err)
|
||||
}
|
||||
|
||||
upstream := map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
BootstrapIP: resolverConfig.IP(supportsIPv6()),
|
||||
Name: resolverConfig.DOH,
|
||||
Endpoint: resolverConfig.DOH,
|
||||
Type: ctrld.ResolverTypeDOH,
|
||||
},
|
||||
}
|
||||
v.Set("upstream", upstream)
|
||||
|
||||
processListenFlag()
|
||||
|
||||
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
|
||||
for _, domain := range resolverConfig.Exclude {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"]
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
|
||||
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
func processListenFlag() {
|
||||
if listenAddress == "" {
|
||||
return
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(listenAddress)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid listener address: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid port number: %v", err)
|
||||
}
|
||||
lc := &ctrld.ListenerConfig{
|
||||
IP: host,
|
||||
Port: port,
|
||||
}
|
||||
v.Set("listener", map[string]*ctrld.ListenerConfig{
|
||||
"0": lc,
|
||||
})
|
||||
}
|
||||
|
||||
func processLogAndCacheFlags() {
|
||||
sc := ctrld.ServiceConfig{}
|
||||
if logPath != "" {
|
||||
sc.LogLevel = "debug"
|
||||
|
||||
@@ -143,6 +143,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
}
|
||||
}
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{"upstream.os"}
|
||||
}
|
||||
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
|
||||
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
|
||||
@@ -204,9 +208,6 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U
|
||||
upstreamNum := strings.TrimPrefix(upstream, "upstream.")
|
||||
upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
|
||||
}
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
}
|
||||
return upstreamConfigs
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,8 @@ var (
|
||||
rootLogger = zerolog.New(io.Discard)
|
||||
mainLog = rootLogger
|
||||
proxyLog = rootLogger
|
||||
|
||||
cdUID string
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -25,6 +25,8 @@ Usage:
|
||||
|
||||
Flags:
|
||||
--base64_config string base64 encoded config
|
||||
--cache_size int Enable cache with size items
|
||||
--cd string Control D resolver uid
|
||||
-c, --config string Path to config file
|
||||
-d, --daemon Run as daemon
|
||||
--domains strings list of domain to apply in a split DNS policy
|
||||
|
||||
15
docs/controld_config.md
Normal file
15
docs/controld_config.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# Control D config
|
||||
|
||||
`ctrld` can build a Control D config and run with the specific resolver data.
|
||||
|
||||
For example:
|
||||
|
||||
```shell
|
||||
ctrld run --cd p2
|
||||
```
|
||||
|
||||
Above command will fetch the `p2` resolver data from Control D API and use that data for running `ctrld`:
|
||||
|
||||
- The resolver `doh` endpoint will be used as the primary upstream.
|
||||
- The resolver `exclude` list will be used to create a rule policy which will steer them to the default OS resolver.
|
||||
```
|
||||
90
internal/controld/config.go
Normal file
90
internal/controld/config.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package controld
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const resolverDataURL = "https://api.controld.com/utility"
|
||||
|
||||
// ResolverConfig represents Control D resolver data.
|
||||
type ResolverConfig struct {
|
||||
V4 []string `json:"v4"`
|
||||
V6 []string `json:"v6"`
|
||||
DOH string `json:"doh"`
|
||||
Exclude []string `json:"exclude"`
|
||||
}
|
||||
|
||||
func (r *ResolverConfig) IP(v6 bool) string {
|
||||
ip4 := r.v4()
|
||||
ip6 := r.v6()
|
||||
if v6 && ip6 != "" {
|
||||
return ip6
|
||||
}
|
||||
return ip4
|
||||
}
|
||||
|
||||
func (r *ResolverConfig) v4() string {
|
||||
for _, ip := range r.V4 {
|
||||
return ip
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *ResolverConfig) v6() string {
|
||||
for _, ip := range r.V6 {
|
||||
return ip
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type utilityResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Body struct {
|
||||
Resolver ResolverConfig `json:"resolver"`
|
||||
} `json:"body"`
|
||||
}
|
||||
|
||||
type utilityErrorResponse struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
type utilityRequest struct {
|
||||
UID string `json:"uid"`
|
||||
}
|
||||
|
||||
// FetchResolverConfig fetch Control D config for given uid.
|
||||
func FetchResolverConfig(uid string) (*ResolverConfig, error) {
|
||||
body, _ := json.Marshal(utilityRequest{UID: uid})
|
||||
req, err := http.NewRequest("POST", resolverDataURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("http.NewRequest: %w", err)
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("client.Do: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
d := json.NewDecoder(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
errResp := &utilityErrorResponse{}
|
||||
if err := d.Decode(errResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, errors.New(errResp.Error.Message)
|
||||
}
|
||||
|
||||
ur := &utilityResponse{}
|
||||
if err := d.Decode(ur); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ur.Body.Resolver, nil
|
||||
}
|
||||
33
internal/controld/config_test.go
Normal file
33
internal/controld/config_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
//go:build controld
|
||||
|
||||
package controld
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const utilityURL = "https://api.controld.com/utility"
|
||||
|
||||
func TestFetchResolverConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uid string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid", "p2", false},
|
||||
{"invalid uid", "abcd1234", true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got, err := FetchResolverConfig(tc.uid)
|
||||
assert.False(t, (err != nil) != tc.wantErr)
|
||||
if !tc.wantErr {
|
||||
assert.NotEmpty(t, got.DOH)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user