mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-05-15 00:50:25 +02:00
all: add starting service with Control D config
This commit is contained in:
committed by
Cuong Manh Le
parent
ec72af1916
commit
114ef9aad6
+70
-20
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -61,15 +62,15 @@ func initCLI() {
|
||||
log.Fatal("Cannot run in daemon mode. Please install a Windows service.")
|
||||
}
|
||||
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
|
||||
noConfigStart := isNoConfigStart(cmd) && cdUID != ""
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
configs := []struct {
|
||||
name string
|
||||
written bool
|
||||
}{
|
||||
// For compatibility, we check for config.toml first, but only read it if exists.
|
||||
{"config", false},
|
||||
{"ctrld", !noConfigStart && configBase64 == ""},
|
||||
{"ctrld", writeDefaultConfig},
|
||||
}
|
||||
for _, config := range configs {
|
||||
ctrld.SetConfigName(v, config.name)
|
||||
@@ -81,6 +82,7 @@ func initCLI() {
|
||||
|
||||
readBase64Config()
|
||||
processNoConfigFlags(noConfigStart)
|
||||
processCDFlags()
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
log.Fatalf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
@@ -138,6 +140,7 @@ func initCLI() {
|
||||
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy")
|
||||
runCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file")
|
||||
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
@@ -183,6 +186,7 @@ func initCLI() {
|
||||
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy")
|
||||
startCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file")
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop",
|
||||
@@ -308,21 +312,24 @@ func writeConfigFile() {
|
||||
}
|
||||
}
|
||||
|
||||
func readConfigFile(configWritten bool) bool {
|
||||
func readConfigFile(writeDefaultConfig bool) bool {
|
||||
// If err == nil, there's a config supplied via `--config`, no default config written.
|
||||
err := v.ReadInConfig()
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if !configWritten {
|
||||
if !writeDefaultConfig {
|
||||
return false
|
||||
}
|
||||
|
||||
// If error is viper.ConfigFileNotFoundError, write default config.
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
writeConfigFile()
|
||||
defaultConfigWritten = true
|
||||
return false
|
||||
}
|
||||
// Otherwise, report fatal error and exit.
|
||||
log.Fatalf("failed to decode config file: %v", err)
|
||||
return false
|
||||
}
|
||||
@@ -347,21 +354,7 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
if listenAddress == "" || primaryUpstream == "" {
|
||||
log.Fatal(`"listen" and "primary_upstream" flags must be set in no config mode`)
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(listenAddress)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid listener address: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid port number: %v", err)
|
||||
}
|
||||
lc := &ctrld.ListenerConfig{
|
||||
IP: host,
|
||||
Port: port,
|
||||
}
|
||||
v.Set("listener", map[string]*ctrld.ListenerConfig{
|
||||
"0": lc,
|
||||
})
|
||||
processListenFlag()
|
||||
|
||||
upstream := map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
@@ -380,10 +373,67 @@ func processNoConfigFlags(noConfigStart bool) {
|
||||
for _, domain := range domains {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{"upstream.1"}})
|
||||
}
|
||||
lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"]
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
|
||||
}
|
||||
v.Set("upstream", upstream)
|
||||
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
func processCDFlags() {
|
||||
if cdUID == "" {
|
||||
return
|
||||
}
|
||||
resolverConfig, err := controld.FetchResolverConfig(cdUID)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to fetch resolver config: %v", err)
|
||||
}
|
||||
|
||||
upstream := map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
BootstrapIP: resolverConfig.IP(supportsIPv6()),
|
||||
Name: resolverConfig.DOH,
|
||||
Endpoint: resolverConfig.DOH,
|
||||
Type: ctrld.ResolverTypeDOH,
|
||||
},
|
||||
}
|
||||
v.Set("upstream", upstream)
|
||||
|
||||
processListenFlag()
|
||||
|
||||
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
|
||||
for _, domain := range resolverConfig.Exclude {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"]
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
|
||||
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
func processListenFlag() {
|
||||
if listenAddress == "" {
|
||||
return
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(listenAddress)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid listener address: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid port number: %v", err)
|
||||
}
|
||||
lc := &ctrld.ListenerConfig{
|
||||
IP: host,
|
||||
Port: port,
|
||||
}
|
||||
v.Set("listener", map[string]*ctrld.ListenerConfig{
|
||||
"0": lc,
|
||||
})
|
||||
}
|
||||
|
||||
func processLogAndCacheFlags() {
|
||||
sc := ctrld.ServiceConfig{}
|
||||
if logPath != "" {
|
||||
sc.LogLevel = "debug"
|
||||
|
||||
@@ -143,6 +143,10 @@ func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []i
|
||||
}
|
||||
}
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{"upstream.os"}
|
||||
}
|
||||
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
ctrld.Log(ctx, proxyLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
|
||||
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
|
||||
@@ -204,9 +208,6 @@ func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.U
|
||||
upstreamNum := strings.TrimPrefix(upstream, "upstream.")
|
||||
upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
|
||||
}
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
}
|
||||
return upstreamConfigs
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,8 @@ var (
|
||||
rootLogger = zerolog.New(io.Discard)
|
||||
mainLog = rootLogger
|
||||
proxyLog = rootLogger
|
||||
|
||||
cdUID string
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
Reference in New Issue
Block a user