all: add CLI flags for no config start

This commit adds the ability to start `ctrld` without config file. All
necessary information can be provided via command line flags, either in
base64 encoded config or launch arguments.
This commit is contained in:
Cuong Manh Le
2022-12-21 19:08:19 +07:00
committed by Cuong Manh Le
parent 30fefe7ab9
commit b93970ccfd
6 changed files with 205 additions and 30 deletions

View File

@@ -1,11 +1,15 @@
package main
import (
"bytes"
"encoding/base64"
"fmt"
"log"
"net"
"os"
"os/exec"
"runtime"
"strconv"
"github.com/go-playground/validator/v10"
"github.com/kardianos/service"
@@ -38,6 +42,7 @@ func initCLI() {
`verbose log output, "-v" means query logging enabled, "-vv" means debug level logging enabled`,
)
basicModeFlags := []string{"listen", "primary_upstream", "secondary_upstream", "domains", "log"}
runCmd := &cobra.Command{
Use: "run",
Short: "Run the DNS proxy server",
@@ -49,14 +54,18 @@ func initCLI() {
if configPath != "" {
v.SetConfigFile(configPath)
}
if err := v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
writeConfigFile()
defaultConfigWritten = true
} else {
log.Fatalf("failed to decode config file: %v", err)
noConfigStart := func() bool {
for _, flagName := range basicModeFlags {
if cmd.Flags().Lookup(flagName).Changed {
return true
}
}
}
return false
}()
readConfigFile(!noConfigStart && configBase64 == "")
readBase64Config()
processNoConfigFlags(noConfigStart)
if err := v.Unmarshal(&cfg); err != nil {
log.Fatalf("failed to unmarshal config: %v", err)
}
@@ -106,6 +115,12 @@ func initCLI() {
}
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "base64 encoded config")
runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "listener address and port, in format: address:port")
runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "primary upstream endpoint")
runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "secondary upstream endpoint")
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "list of domain to apply in a split DNS policy")
runCmd.Flags().StringVarP(&logPath, "log", "", "", "path to log file")
rootCmd.AddCommand(runCmd)
@@ -125,3 +140,78 @@ func writeConfigFile() {
log.Printf("failed to write config file: %v\n", err)
}
}
func readConfigFile(configWritten bool) {
err := v.ReadInConfig()
if err == nil || !configWritten {
return
}
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
writeConfigFile()
defaultConfigWritten = true
return
}
log.Fatalf("failed to decode config file: %v", err)
}
func readBase64Config() {
if configBase64 == "" {
return
}
configStr, err := base64.StdEncoding.DecodeString(configBase64)
if err != nil {
log.Fatalf("invalid base64 config: %v", err)
}
if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil {
log.Fatalf("failed to read base64 config: %v", err)
}
}
func processNoConfigFlags(noConfigStart bool) {
if !noConfigStart {
return
}
if listenAddress == "" || primaryUpstream == "" {
log.Fatal(`"listen" and "primary_upstream" flags must be set in no config mode`)
}
host, portStr, err := net.SplitHostPort(listenAddress)
if err != nil {
log.Fatalf("invalid listener address: %v", err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Fatalf("invalid port number: %v", err)
}
lc := &ctrld.ListenerConfig{
IP: host,
Port: port,
}
v.Set("listener", map[string]*ctrld.ListenerConfig{
"0": lc,
})
upstream := map[string]*ctrld.UpstreamConfig{
"0": {
Name: primaryUpstream,
Endpoint: primaryUpstream,
Type: ctrld.ResolverTypeDOH,
},
}
if secondaryUpstream != "" {
upstream["1"] = &ctrld.UpstreamConfig{
Name: secondaryUpstream,
Endpoint: secondaryUpstream,
Type: ctrld.ResolverTypeLegacy,
}
rules := make([]ctrld.Rule, 0, len(domains))
for _, domain := range domains {
rules = append(rules, ctrld.Rule{domain: []string{"upstream.1"}})
}
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
}
v.Set("upstream", upstream)
if logPath != "" {
v.Set("service", ctrld.ServiceConfig{LogLevel: "debug", LogPath: logPath})
}
}

View File

@@ -12,10 +12,16 @@ import (
)
var (
configPath string
daemon bool
cfg ctrld.Config
verbose int
configPath string
configBase64 string
daemon bool
listenAddress string
primaryUpstream string
secondaryUpstream string
domains []string
logPath string
cfg ctrld.Config
verbose int
bootstrapDNS = "76.76.2.0"