From 8dd90cb354715230c6385ed44258c040fe652bb9 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Wed, 7 Jan 2026 17:11:38 +0700 Subject: [PATCH] fix(config): use three-state atomic for rebootstrap to prevent data race Replace boolean rebootstrap flag with a three-state atomic integer to prevent concurrent SetupTransport calls during rebootstrap. The atomic state machine ensures only one goroutine can proceed from "started" to "in progress", eliminating the need for a mutex while maintaining thread safety. States: NotStarted -> Started -> InProgress -> NotStarted Note that the race condition is still acceptable because any additional transports created during the race are functional. Once the connection is established, the unused transports are safely handled by the garbage collector. --- config.go | 12 ++++++++--- config_internal_test.go | 47 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/config.go b/config.go index 63c7f6a..c7ad161 100644 --- a/config.go +++ b/config.go @@ -82,6 +82,10 @@ const ( endpointPrefixQUIC = "quic://" endpointPrefixH3 = "h3://" endpointPrefixSdns = "sdns://" + + rebootstrapNotStarted = 0 + rebootstrapStarted = 1 + rebootstrapInProgress = 2 ) var ( @@ -270,7 +274,7 @@ type UpstreamConfig struct { Discoverable *bool `mapstructure:"discoverable" toml:"discoverable"` g singleflight.Group - rebootstrap atomic.Bool + rebootstrap atomic.Int64 bootstrapIPs []string bootstrapIPs4 []string bootstrapIPs6 []string @@ -511,7 +515,7 @@ func (uc *UpstreamConfig) ReBootstrap(ctx context.Context) { return } _, _, _ = uc.g.Do("ReBootstrap", func() (any, error) { - if uc.rebootstrap.CompareAndSwap(false, true) { + if uc.rebootstrap.CompareAndSwap(rebootstrapNotStarted, rebootstrapStarted) { logger := LoggerFromCtx(ctx) Log(ctx, logger.Debug(), "Re-bootstrapping upstream: %s", uc.Name) } @@ -557,8 +561,10 @@ func (uc *UpstreamConfig) ensureSetupTransport(ctx context.Context) { uc.transportOnce.Do(func() { uc.SetupTransport(ctx) }) - if uc.rebootstrap.CompareAndSwap(true, false) { + + if uc.rebootstrap.CompareAndSwap(rebootstrapStarted, rebootstrapInProgress) { uc.SetupTransport(ctx) + uc.rebootstrap.Store(rebootstrapNotStarted) } } diff --git a/config_internal_test.go b/config_internal_test.go index 0e7f3bb..24f85b6 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -3,6 +3,7 @@ package ctrld import ( "context" "net/url" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -506,6 +507,52 @@ func TestUpstreamConfig_IsDiscoverable(t *testing.T) { } } +func TestRebootstrapRace(t *testing.T) { + uc := &UpstreamConfig{ + Name: "test-doh", + Type: ResolverTypeDOH, + Endpoint: "https://example.com/dns-query", + Domain: "example.com", + bootstrapIPs: []string{"1.1.1.1", "1.0.0.1"}, + } + + ctx := LoggerCtx(context.Background(), NopLogger) + + uc.SetupTransport(ctx) + + if uc.transport == nil { + t.Fatal("initial transport should be set") + } + + const goroutines = 100 + + uc.ReBootstrap(ctx) + + started := make(chan struct{}) + go func() { + close(started) + for { + switch uc.rebootstrap.Load() { + case rebootstrapStarted, rebootstrapInProgress: + uc.ReBootstrap(ctx) + default: + return + } + } + }() + + <-started + + var wg sync.WaitGroup + for range goroutines { + wg.Go(func() { + uc.ensureSetupTransport(ctx) + }) + } + + wg.Wait() +} + func ptrBool(b bool) *bool { return &b }