diff --git a/cmd/cli/dns_proxy.go b/cmd/cli/dns_proxy.go
index 8053a89..33ca60c 100644
--- a/cmd/cli/dns_proxy.go
+++ b/cmd/cli/dns_proxy.go
@@ -1574,83 +1574,131 @@ func (p *prog) checkUpstreamOnce(upstream string, uc *ctrld.UpstreamConfig) erro
 	return err
 }
 
-// handleRecovery performs a unified recovery by removing DNS settings,
-// canceling existing recovery checks for network changes, but coalescing duplicate
-// upstream failure recoveries, waiting for recovery to complete (using a cancellable context without timeout),
-// and then re-applying the DNS settings.
+// handleRecovery orchestrates the recovery process by coordinating multiple smaller methods.
+// It handles recovery cancellation logic, creates the recovery context, prepares the system,
+// waits indefinitely for an upstream to recover (using a cancellable context without a timeout),
+// and then completes the recovery process.
+// The method is designed to be called from a goroutine and handles different recovery reasons
+// (network changes, regular failures, OS failures) with appropriate logic for each.
 func (p *prog) handleRecovery(reason RecoveryReason) {
 	p.Debug().Msg("Starting recovery process: removing DNS settings")
 
-	// For network changes, cancel any existing recovery check because the network state has changed.
+	// Handle recovery cancellation based on the reason
+	if !p.shouldStartRecovery(reason) {
+		return
+	}
+
+	// Create recovery context and cleanup function
+	recoveryCtx, cleanup := p.createRecoveryContext()
+	defer cleanup()
+
+	// Remove DNS settings and prepare for recovery
+	if err := p.prepareForRecovery(reason); err != nil {
+		p.Error().Err(err).Msg("Failed to prepare for recovery")
+		return
+	}
+
+	// Build upstream map based on the recovery reason
+	upstreams := p.buildRecoveryUpstreams(reason)
+
+	// Wait for upstream recovery
+	recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
+	if err != nil {
+		p.Error().Err(err).Msg("Recovery failed; DNS settings remain removed")
+		return
+	}
+
+	// Complete recovery process
+	if err := p.completeRecovery(reason, recovered); err != nil {
+		p.Error().Err(err).Msg("Failed to complete recovery")
+		return
+	}
+
+	p.Info().Msgf("Recovery completed successfully for upstream %q", recovered)
+}
+
+// shouldStartRecovery determines if recovery should start based on the reason and current state.
+// Returns true if recovery should proceed, false otherwise.
+func (p *prog) shouldStartRecovery(reason RecoveryReason) bool {
+	p.recoveryCancelMu.Lock()
+	defer p.recoveryCancelMu.Unlock()
+
 	if reason == RecoveryReasonNetworkChange {
-		p.recoveryCancelMu.Lock()
+		// For network changes, cancel any existing recovery check because the network state has changed.
 		if p.recoveryCancel != nil {
 			p.Debug().Msg("Cancelling existing recovery check (network change)")
 			p.recoveryCancel()
 			p.recoveryCancel = nil
 		}
-		p.recoveryCancelMu.Unlock()
-	} else {
-		// For upstream failures, if a recovery is already in progress, do nothing new.
-		p.recoveryCancelMu.Lock()
-		if p.recoveryCancel != nil {
-			p.Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
-			p.recoveryCancelMu.Unlock()
-			return
-		}
-		p.recoveryCancelMu.Unlock()
+		return true
 	}
 
-	// Create a new recovery context without a fixed timeout.
+	// For upstream failures, if a recovery is already in progress, do nothing new.
+	if p.recoveryCancel != nil {
+		p.Debug().Msg("Upstream recovery already in progress; skipping duplicate trigger")
+		return false
+	}
+
+	return true
+}
+
+// createRecoveryContext creates a new recovery context and returns it along with a cleanup function.
+func (p *prog) createRecoveryContext() (context.Context, func()) {
 	p.recoveryCancelMu.Lock()
 	recoveryCtx, cancel := context.WithCancel(context.Background())
 	p.recoveryCancel = cancel
 	p.recoveryCancelMu.Unlock()
 
-	// Immediately remove our DNS settings from the interface.
-	// set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
+	cleanup := func() {
+		cancel() // release the context's resources
+		p.recoveryCancelMu.Lock()
+		p.recoveryCancel = nil
+		p.recoveryCancelMu.Unlock()
+	}
+
+	return recoveryCtx, cleanup
+}
+
+// prepareForRecovery removes DNS settings and initializes the OS resolver if needed.
+func (p *prog) prepareForRecovery(reason RecoveryReason) error {
+	// Set recoveryRunning to true to prevent watchdogs from putting the listener back on the interface
 	p.recoveryRunning.Store(true)
-	// we do not want to restore any static DNS settings
+
+	// Remove DNS settings - we do not want to restore any static DNS settings
 	// we must try to get the DHCP values, any static DNS settings
 	// will be appended to nameservers from the saved interface values
 	p.resetDNS(false, false)
 
-	loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
 	// For an OS failure, reinitialize OS resolver nameservers immediately.
 	if reason == RecoveryReasonOSFailure {
-		p.Debug().Msg("OS resolver failure detected; reinitializing OS resolver nameservers")
-		ns := ctrld.InitializeOsResolver(loggerCtx, true)
-		if len(ns) == 0 {
-			p.Warn().Msg("No nameservers found for OS resolver; using existing values")
-		} else {
-			p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
+		if err := p.reinitializeOSResolver("OS resolver failure detected"); err != nil {
+			return fmt.Errorf("failed to reinitialize OS resolver: %w", err)
 		}
 	}
 
-	// Build upstream map based on the recovery reason.
-	upstreams := p.buildRecoveryUpstreams(reason)
+	return nil
+}
 
-	// Wait indefinitely until one of the upstreams recovers.
-	recovered, err := p.waitForUpstreamRecovery(recoveryCtx, upstreams)
-	if err != nil {
-		p.Error().Err(err).Msg("Recovery canceled; DNS settings remain removed")
-		p.recoveryCancelMu.Lock()
-		p.recoveryCancel = nil
-		p.recoveryCancelMu.Unlock()
-		return
+// reinitializeOSResolver reinitializes the OS resolver and logs the results.
+func (p *prog) reinitializeOSResolver(message string) error {
+	p.Debug().Msg(message)
+	loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
+	ns := ctrld.InitializeOsResolver(loggerCtx, true)
+	if len(ns) == 0 {
+		p.Warn().Msg("No nameservers found for OS resolver; using existing values")
+	} else {
+		p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
 	}
-	p.Info().Msgf("Upstream %q recovered; re-applying DNS settings", recovered)
+	return nil
+}
 
-	// reset the upstream failure count and down state
+// completeRecovery completes the recovery process by resetting upstream state and reapplying DNS settings.
+func (p *prog) completeRecovery(reason RecoveryReason, recovered string) error {
+	// Reset the upstream failure count and down state
 	p.um.reset(recovered)
 
 	// For network changes we also reinitialize the OS resolver.
 	if reason == RecoveryReasonNetworkChange {
-		ns := ctrld.InitializeOsResolver(loggerCtx, true)
-		if len(ns) == 0 {
-			p.Warn().Msg("No nameservers found for OS resolver during network-change recovery; using existing values")
-		} else {
-			p.Info().Msgf("Reinitialized OS resolver with nameservers: %v", ns)
+		if err := p.reinitializeOSResolver("Network change detected during recovery"); err != nil {
+			return fmt.Errorf("failed to reinitialize OS resolver during network change: %w", err)
 		}
 	}
 
@@ -1658,13 +1706,10 @@ func (p *prog) handleRecovery(reason RecoveryReason) {
 	p.setDNS()
 	p.logInterfacesState()
 
-	// allow watchdogs to put the listener back on the interface if its changed for any reason
+	// Allow watchdogs to put the listener back on the interface if it's changed for any reason
 	p.recoveryRunning.Store(false)
 
-	// Clear the recovery cancellation for a clean slate.
-	p.recoveryCancelMu.Lock()
-	p.recoveryCancel = nil
-	p.recoveryCancelMu.Unlock()
+	return nil
 }
 
 // waitForUpstreamRecovery checks the provided upstreams concurrently until one recovers.
diff --git a/cmd/cli/dns_proxy_test.go b/cmd/cli/dns_proxy_test.go
index 615ce40..75db216 100644
--- a/cmd/cli/dns_proxy_test.go
+++ b/cmd/cli/dns_proxy_test.go
@@ -466,3 +466,254 @@ func Test_isWanClient(t *testing.T) {
 		})
 	}
 }
+
+func Test_shouldStartRecovery(t *testing.T) {
+	tests := []struct {
+		name                string
+		reason              RecoveryReason
+		hasExistingRecovery bool
+		expectedResult      bool
+		description         string
+	}{
+		{
+			name:                "network change with existing recovery",
+			reason:              RecoveryReasonNetworkChange,
+			hasExistingRecovery: true,
+			expectedResult:      true,
+			description:         "should cancel existing recovery and start new one for network change",
+		},
+		{
+			name:                "network change without existing recovery",
+			reason:              RecoveryReasonNetworkChange,
+			hasExistingRecovery: false,
+			expectedResult:      true,
+			description:         "should start new recovery for network change",
+		},
+		{
+			name:                "regular failure with existing recovery",
+			reason:              RecoveryReasonRegularFailure,
+			hasExistingRecovery: true,
+			expectedResult:      false,
+			description:         "should skip duplicate recovery for regular failure",
+		},
+		{
+			name:                "regular failure without existing recovery",
+			reason:              RecoveryReasonRegularFailure,
+			hasExistingRecovery: false,
+			expectedResult:      true,
+			description:         "should start new recovery for regular failure",
+		},
+		{
+			name:                "OS failure with existing recovery",
+			reason:              RecoveryReasonOSFailure,
+			hasExistingRecovery: true,
+			expectedResult:      false,
+			description:         "should skip duplicate recovery for OS failure",
+		},
+		{
+			name:                "OS failure without existing recovery",
+			reason:              RecoveryReasonOSFailure,
+			hasExistingRecovery: false,
+			expectedResult:      true,
+			description:         "should start new recovery for OS failure",
+		},
+	}
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			p := newTestProg(t)
+
+			// Set up an existing recovery if needed
+			if tc.hasExistingRecovery {
+				p.recoveryCancelMu.Lock()
+				p.recoveryCancel = func() {} // Mock cancel function
+				p.recoveryCancelMu.Unlock()
+			}
+
+			result := p.shouldStartRecovery(tc.reason)
+			assert.Equal(t, tc.expectedResult, result, tc.description)
+		})
+	}
+}
+
+func Test_createRecoveryContext(t *testing.T) {
+	p := newTestProg(t)
+
+	ctx, cleanup := p.createRecoveryContext()
+
+	// Verify context is created
+	assert.NotNil(t, ctx)
+	assert.NotNil(t, cleanup)
+
+	// Verify recoveryCancel is set
+	p.recoveryCancelMu.Lock()
+	assert.NotNil(t, p.recoveryCancel)
+	p.recoveryCancelMu.Unlock()
+
+	// Test cleanup function
+	cleanup()
+
+	// Verify recoveryCancel is cleared
+	p.recoveryCancelMu.Lock()
+	assert.Nil(t, p.recoveryCancel)
+	p.recoveryCancelMu.Unlock()
+}
+
+func Test_prepareForRecovery(t *testing.T) {
+	tests := []struct {
+		name    string
+		reason  RecoveryReason
+		wantErr bool
+	}{
+		{
+			name:    "regular failure",
+			reason:  RecoveryReasonRegularFailure,
+			wantErr: false,
+		},
+		{
+			name:    "network change",
+			reason:  RecoveryReasonNetworkChange,
+			wantErr: false,
+		},
+		{
+			name:    "OS failure",
+			reason:  RecoveryReasonOSFailure,
+			wantErr: false,
+		},
+	}
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			p := newTestProg(t)
+
+			err := p.prepareForRecovery(tc.reason)
+
+			if tc.wantErr {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+			}
+
+			// Verify recoveryRunning is set to true
+			assert.True(t, p.recoveryRunning.Load())
+		})
+	}
+}
+
+func Test_completeRecovery(t *testing.T) {
+	tests := []struct {
+		name      string
+		reason    RecoveryReason
+		recovered string
+		wantErr   bool
+	}{
+		{
+			name:      "regular failure recovery",
+			reason:    RecoveryReasonRegularFailure,
+			recovered: "upstream1",
+			wantErr:   false,
+		},
+		{
+			name:      "network change recovery",
+			reason:    RecoveryReasonNetworkChange,
+			recovered: "upstream2",
+			wantErr:   false,
+		},
+		{
+			name:      "OS failure recovery",
+			reason:    RecoveryReasonOSFailure,
+			recovered: "upstream3",
+			wantErr:   false,
+		},
+	}
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			p := newTestProg(t)
+
+			err := p.completeRecovery(tc.reason, tc.recovered)
+
+			if tc.wantErr {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+			}
+
+			// Verify recoveryRunning is set to false
+			assert.False(t, p.recoveryRunning.Load())
+		})
+	}
+}
+
+func Test_reinitializeOSResolver(t *testing.T) {
+	p := newTestProg(t)
+
+	err := p.reinitializeOSResolver("Test message")
+
+	// This function should not return an error under normal circumstances.
+	// The actual behavior depends on the OS resolver implementation.
+	assert.NoError(t, err)
+}
+
+func Test_handleRecovery_Integration(t *testing.T) {
+	tests := []struct {
+		name    string
+		reason  RecoveryReason
+		wantErr bool
+	}{
+		{
+			name:    "network change recovery",
+			reason:  RecoveryReasonNetworkChange,
+			wantErr: false,
+		},
+		{
+			name:    "regular failure recovery",
+			reason:  RecoveryReasonRegularFailure,
+			wantErr: false,
+		},
+		{
+			name:    "OS failure recovery",
+			reason:  RecoveryReasonOSFailure,
+			wantErr: false,
+		},
+	}
+
+	for _, tc := range tests {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			p := newTestProg(t)
+
+			// This is an integration test that exercises the full recovery flow.
+			// In a real test environment, you would mock the dependencies.
+			// For now, we're just testing that the method doesn't panic
+			// and that the recovery logic flows correctly.
+			assert.NotPanics(t, func() {
+				// Test only the preparation phase to avoid actual upstream checking
+				if !p.shouldStartRecovery(tc.reason) {
+					return
+				}
+
+				_, cleanup := p.createRecoveryContext()
+				defer cleanup()
+
+				if err := p.prepareForRecovery(tc.reason); err != nil {
+					return
+				}
+
+				// Skip the actual upstream recovery check for this test
+				// as it requires properly configured upstreams
+			})
+		})
+	}
+}
+
+// newTestProg creates a properly initialized *prog for testing.
+func newTestProg(t *testing.T) *prog {
+	p := &prog{cfg: testhelper.SampleConfig(t)}
+	p.logger.Store(mainLog.Load())
+	p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
+	return p
+}
diff --git a/go.mod b/go.mod
index 1d94a07..a911c76 100644
--- a/go.mod
+++ b/go.mod
@@ -32,7 +32,6 @@ require (
 	github.com/quic-go/quic-go v0.48.2
 	github.com/rs/zerolog v1.28.0
 	github.com/spf13/cobra v1.8.1
-	github.com/spf13/pflag v1.0.5
 	github.com/spf13/viper v1.16.0
 	github.com/stretchr/testify v1.9.0
 	github.com/vishvananda/netlink v1.2.1-beta.2
@@ -86,6 +85,7 @@ require (
 	github.com/spf13/afero v1.9.5 // indirect
 	github.com/spf13/cast v1.6.0 // indirect
 	github.com/spf13/jwalterweatherman v1.1.0 // indirect
+	github.com/spf13/pflag v1.0.5 // indirect
 	github.com/subosito/gotenv v1.4.2 // indirect
 	github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e // indirect
 	github.com/vishvananda/netns v0.0.4 // indirect