diff --git a/config.go b/config.go index 63c7f6a..c7ad161 100644 --- a/config.go +++ b/config.go @@ -82,6 +82,10 @@ const ( endpointPrefixQUIC = "quic://" endpointPrefixH3 = "h3://" endpointPrefixSdns = "sdns://" + + rebootstrapNotStarted = 0 + rebootstrapStarted = 1 + rebootstrapInProgress = 2 ) var ( @@ -270,7 +274,7 @@ type UpstreamConfig struct { Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"` g singleflight.Group - rebootstrap atomic.Bool + rebootstrap atomic.Int64 bootstrapIPs []string bootstrapIPs4 []string bootstrapIPs6 []string @@ -511,7 +515,7 @@ func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { return } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { - if uc.rebootstrap.CompareAndSwap(false, true) { + if uc.rebootstrap.CompareAndSwap(rebootstrapNotStarted, rebootstrapStarted) { logger := LoggerFromCtx(ctx) Log(ctx, logger.Debug(), "Re-bootstrapping upstream: %s", uc.Name) } @@ -557,8 +561,10 @@ func (uc *UpstreamConfig) ensureSetupTransport(ctx context.Context) { uc.transportOnce.Do(func() { uc.SetupTransport(ctx) }) - if uc.rebootstrap.CompareAndSwap(true, false) { + + if uc.rebootstrap.CompareAndSwap(rebootstrapStarted, rebootstrapInProgress) { uc.SetupTransport(ctx) + uc.rebootstrap.Store(rebootstrapNotStarted) } } diff --git a/config_internal_test.go b/config_internal_test.go index 0e7f3bb..24f85b6 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -3,6 +3,7 @@ package ctrld import ( "context" "net/url" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -506,6 +507,52 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) { } } +func TestRebootstrapRace(t *testing.T) { + uc := &UpstreamConfig{ + Name: "test-doh", + Type: ResolverTypeDOH, + Endpoint: "https://example.com/dns-query", + Domain: "example.com", + bootstrapIPs: []string{"1.1.1.1", "1.0.0.1"}, + } + + ctx := LoggerCtx(context.Background(), NopLogger) + + uc.SetupTransport(ctx) + + if uc.transport == nil { + t.Fatal("initial transport should be set") + } + + const goroutines = 100 + + uc.ReBootstrap(ctx) + + started := make(chan struct{}) + go func() { + close(started) + for { + switch uc.rebootstrap.Load() { + case rebootstrapStarted, rebootstrapInProgress: + uc.ReBootstrap(ctx) + default: + return + } + } + }() + + <-started + + var wg sync.WaitGroup + for range goroutines { + wg.Go(func() { + uc.ensureSetupTransport(ctx) + }) + } + + wg.Wait() +} + func ptrBool(b bool) *bool { return &b }