diff --git a/cmd/ctrld/dns_proxy.go b/cmd/ctrld/dns_proxy.go index 9f9fa30..366fafb 100644 --- a/cmd/ctrld/dns_proxy.go +++ b/cmd/ctrld/dns_proxy.go @@ -47,6 +47,8 @@ func (p *prog) serveDNS(listenerNum string) error { failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers } handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + p.sema.acquire() + defer p.sema.release() q := m.Question[0] domain := canonicalName(q.Name) reqId := requestID() @@ -60,7 +62,6 @@ func (p *prog) serveDNS(listenerNum string) error { if !matched && listenerConfig.Restricted { answer = new(dns.Msg) answer.SetRcode(m, dns.RcodeRefused) - } else { answer = p.proxy(ctx, upstreams, failoverRcodes, m) rtt := time.Since(t) diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 807d22c..5ec372c 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -17,6 +17,8 @@ import ( "github.com/Control-D-Inc/ctrld/internal/router" ) +const defaultSemaphoreCap = 256 + var logf = func(format string, args ...any) { mainLog.Debug().Msgf(format, args...) } @@ -36,6 +38,7 @@ type prog struct { cfg *ctrld.Config cache dnscache.Cacher + sema semaphore } func (p *prog) Start(s service.Service) error { @@ -56,6 +59,15 @@ func (p *prog) run() { p.cache = cacher } } + p.sema = &chanSemaphore{ready: make(chan struct{}, defaultSemaphoreCap)} + if mcr := p.cfg.Service.MaxConcurrentRequests; mcr != nil { + n := *mcr + if n == 0 { + p.sema = &noopSemaphore{} + } else { + p.sema = &chanSemaphore{ready: make(chan struct{}, n)} + } + } var wg sync.WaitGroup wg.Add(len(p.cfg.Listener)) diff --git a/cmd/ctrld/sema.go b/cmd/ctrld/sema.go new file mode 100644 index 0000000..8faa9d2 --- /dev/null +++ b/cmd/ctrld/sema.go @@ -0,0 +1,24 @@ +package main + +type semaphore interface { + acquire() + release() +} + +type noopSemaphore struct{} + +func (n noopSemaphore) acquire() {} + +func (n noopSemaphore) release() {} + +type chanSemaphore struct { + ready chan struct{} +} + +func (c *chanSemaphore) acquire() { + c.ready <- struct{}{} +} + +func (c *chanSemaphore) release() { + <-c.ready +} diff --git a/config.go b/config.go index bdd335b..6fa54e4 100644 --- a/config.go +++ b/config.go @@ -122,14 +122,15 @@ func (c *Config) HasUpstreamSendClientInfo() bool { // ServiceConfig specifies the general ctrld config. type ServiceConfig struct { - LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` - LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` - CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` - CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` - CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` - CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` - Daemon bool `mapstructure:"-" toml:"-"` - AllocateIP bool `mapstructure:"-" toml:"-"` + LogLevel string `mapstructure:"log_level" toml:"log_level,omitempty"` + LogPath string `mapstructure:"log_path" toml:"log_path,omitempty"` + CacheEnable bool `mapstructure:"cache_enable" toml:"cache_enable,omitempty"` + CacheSize int `mapstructure:"cache_size" toml:"cache_size,omitempty"` + CacheTTLOverride int `mapstructure:"cache_ttl_override" toml:"cache_ttl_override,omitempty"` + CacheServeStale bool `mapstructure:"cache_serve_stale" toml:"cache_serve_stale,omitempty"` + MaxConcurrentRequests *int `mapstructure:"max_concurrent_requests" toml:"max_concurrent_requests,omitempty" validate:"omitempty,gte=0"` + Daemon bool `mapstructure:"-" toml:"-"` + AllocateIP bool `mapstructure:"-" toml:"-"` } // NetworkConfig specifies configuration for networks where ctrld will handle requests. diff --git a/config_test.go b/config_test.go index ddbc97b..27f9d40 100644 --- a/config_test.go +++ b/config_test.go @@ -75,6 +75,7 @@ func TestConfigValidation(t *testing.T) { {"os upstream", configWithOsUpstream(t), false}, {"invalid rules", configWithInvalidRules(t), true}, {"invalid dns rcodes", configWithInvalidRcodes(t), true}, + {"invalid max concurrent requests", configWithInvalidMaxConcurrentRequests(t), true}, } for _, tc := range tests { @@ -176,3 +177,10 @@ func configWithInvalidRcodes(t *testing.T) *ctrld.Config { } return cfg } + +func configWithInvalidMaxConcurrentRequests(t *testing.T) *ctrld.Config { + cfg := defaultConfig(t) + n := -1 + cfg.Service.MaxConcurrentRequests = &n + return cfg +} diff --git a/docs/config.md b/docs/config.md index f699b4b..8af8e40 100644 --- a/docs/config.md +++ b/docs/config.md @@ -157,6 +157,14 @@ stale cached records (regardless of their TTLs) until upstream comes online. - Required: no - Default: false +### max_concurrent_requests +The number of concurrent requests that will be handled, must be a non-negative integer. +Tweaking this value depends on the capacity of your system. + +- Type: number +- Required: no +- Default: 256 + ## Upstream The `[upstream]` section specifies the DNS upstream servers that `ctrld` will forward DNS requests to.