diff --git a/doh.go b/doh.go index 242f759..367e419 100644 --- a/doh.go +++ b/doh.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "net/url" + "runtime" "github.com/miekg/dns" ) @@ -27,6 +28,7 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver { transport: uc.transport, http3RoundTripper: uc.http3RoundTripper, sendClientInfo: uc.UpstreamSendClientInfo(), + uc: uc, } return r } @@ -37,6 +39,7 @@ type dohResolver struct { transport *http.Transport http3RoundTripper http.RoundTripper sendClientInfo bool + uc *UpstreamConfig } func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) { @@ -57,14 +60,12 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro } addHeader(ctx, req, r.sendClientInfo) - c := http.Client{Transport: r.transport} - if r.isDoH3 { - if r.http3RoundTripper == nil { - return nil, errors.New("DoH3 is not supported") - } - c.Transport = r.http3RoundTripper + var resp *http.Response + if runtime.GOOS == "linux" { + resp, err = r.doRequestWithFailover(req) + } else { + resp, err = r.doRequest(req) } - resp, err := c.Do(req) if err != nil { if r.isDoH3 { if closer, ok := r.http3RoundTripper.(io.Closer); ok { @@ -91,6 +92,92 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro return answer, nil } +func (r *dohResolver) doRequest(req *http.Request) (*http.Response, error) { + c := http.Client{Transport: r.transport} + if r.isDoH3 { + if r.http3RoundTripper == nil { + return nil, errors.New("DoH3 is not supported") + } + c.Transport = r.http3RoundTripper + } + return c.Do(req) +} + +func (r *dohResolver) doRequestWithFailover(req *http.Request) (*http.Response, error) { + // To deal with network changes, for example, connect/disconnect to VPN, + // We use two clients: + // + // - mainClient: use the current transport. + // - failoverClient: use a clone of the current transport. + // + // Two clients will perform the requests concurrently, but with mainClient + // started first. So in normal condition, mainClient is likely to return first, + // and we will use its result. In case of mainClient failed, we trigger the + // re-bootstrapping process, and use the result from failover client. + mainClient := http.Client{Transport: r.transport} + failoverClient := http.Client{} + if r.isDoH3 { + if r.http3RoundTripper == nil { + return nil, errors.New("DoH3 is not supported") + } + // TODO: figure out how to clone DOH3 round tripper? + mainClient.Transport = r.http3RoundTripper + failoverClient.Transport = r.http3RoundTripper + } else { + failoverClient.Transport = r.transport.Clone() + } + + done := make(chan struct{}) + defer close(done) + + type result struct { + resp *http.Response + err error + } + + respCh := make(chan result) + doRequest := func(client *http.Client) { + resp, err := client.Do(req) + select { + case respCh <- result{resp: resp, err: err}: + case <-done: + if client == &mainClient && err != nil { + r.uc.ReBootstrap() + } + if resp != nil { + defer resp.Body.Close() + _, _ = io.Copy(io.Discard, resp.Body) + } + } + } + + mainClientStarted := make(chan struct{}) + go func() { + // Notify failoverClient that mainClient started. + close(mainClientStarted) + doRequest(&mainClient) + }() + go func() { + // Wait mainClient started first. + <-mainClientStarted + doRequest(&failoverClient) + }() + + var ( + resp *http.Response + err error + ) + for range []*http.Client{&mainClient, &failoverClient} { + res := <-respCh + if res.err == nil { + resp = res.resp + break + } + err = res.err + } + return resp, err +} + func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) { req.Header.Set("Content-Type", headerApplicationDNS) req.Header.Set("Accept", headerApplicationDNS)