diff --git a/config.go b/config.go index 0975d22..ad55dad 100644 --- a/config.go +++ b/config.go @@ -82,6 +82,10 @@ const ( endpointPrefixQUIC = "quic://" endpointPrefixH3 = "h3://" endpointPrefixSdns = "sdns://" + + rebootstrapNotStarted = 0 + rebootstrapStarted = 1 + rebootstrapInProgress = 2 ) var ( @@ -264,7 +268,7 @@ type UpstreamConfig struct { Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"` g singleflight.Group - rebootstrap atomic.Bool + rebootstrap atomic.Int64 bootstrapIPs []string bootstrapIPs4 []string bootstrapIPs6 []string @@ -497,7 +501,7 @@ func (uc *UpstreamConfig) ReBootstrap() { return } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { - if uc.rebootstrap.CompareAndSwap(false, true) { + if uc.rebootstrap.CompareAndSwap(rebootstrapNotStarted, rebootstrapStarted) { ProxyLogger.Load().Debug().Msgf("re-bootstrapping upstream ip for %v", uc) } return true, nil @@ -542,8 +546,9 @@ func (uc *UpstreamConfig) ensureSetupTransport() { uc.transportOnce.Do(func() { uc.SetupTransport() }) - if uc.rebootstrap.CompareAndSwap(true, false) { + if uc.rebootstrap.CompareAndSwap(rebootstrapStarted, rebootstrapInProgress) { uc.SetupTransport() + uc.rebootstrap.Store(rebootstrapNotStarted) } } diff --git a/config_internal_test.go b/config_internal_test.go index b37e982..ca2b381 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -2,6 +2,7 @@ package ctrld import ( "net/url" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -505,6 +506,50 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) { } } +func TestRebootstrapRace(t *testing.T) { + uc := &UpstreamConfig{ + Name: "test-doh", + Type: ResolverTypeDOH, + Endpoint: "https://example.com/dns-query", + Domain: "example.com", + bootstrapIPs: []string{"1.1.1.1", "1.0.0.1"}, + } + + uc.SetupTransport() + + if uc.transport == nil { + t.Fatal("initial transport should be set") + } + + const goroutines = 100 + + uc.ReBootstrap() + + started := make(chan struct{}) + go func() { + close(started) + for { + switch uc.rebootstrap.Load() { + case rebootstrapStarted, rebootstrapInProgress: + uc.ReBootstrap() + default: + return + } + } + }() + + <-started + + var wg sync.WaitGroup + for range goroutines { + wg.Go(func() { + uc.ensureSetupTransport() + }) + } + + wg.Wait() +} + func ptrBool(b bool) *bool { return &b }