From 56d8dc865f37f6d22feae21bcf5296b68d3c9a9d Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Sat, 29 Apr 2023 13:15:10 +0700 Subject: [PATCH] Use different failover mechanism on Linux Instead of always doubling the request, first we wrap the request with a failover timeout, 500ms, which is an average time for a normal request. If this request failed, trigger re-bootstrapping and retry the request. --- doh.go | 82 ++++++++++++---------------------------------------------- 1 file changed, 16 insertions(+), 66 deletions(-) diff --git a/doh.go b/doh.go index 367e419..76f81ea 100644 --- a/doh.go +++ b/doh.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "runtime" + "time" "github.com/miekg/dns" ) @@ -103,79 +104,28 @@ func (r *dohResolver) doRequest(req *http.Request) (*http.Response, error) { return c.Do(req) } +const failoverTimeout = 500 * time.Millisecond + +// doRequestWithFailover is like doRequest, but wrap the request with initial timeout. +// If the first request failed, it's likely that the transport was broken, then trigger +// re-bootstrapping and retry the request. func (r *dohResolver) doRequestWithFailover(req *http.Request) (*http.Response, error) { - // To deal with network changes, for example, connect/disconnect to VPN, - // We use two clients: - // - // - mainClient: use the current transport. - // - failoverClient: use a clone of the current transport. - // - // Two clients will perform the requests concurrently, but with mainClient - // started first. So in normal condition, mainClient is likely to return first, - // and we will use its result. In case of mainClient failed, we trigger the - // re-bootstrapping process, and use the result from failover client. - mainClient := http.Client{Transport: r.transport} - failoverClient := http.Client{} + c := http.Client{Transport: r.transport} if r.isDoH3 { if r.http3RoundTripper == nil { return nil, errors.New("DoH3 is not supported") } - // TODO: figure out how to clone DOH3 round tripper? - mainClient.Transport = r.http3RoundTripper - failoverClient.Transport = r.http3RoundTripper - } else { - failoverClient.Transport = r.transport.Clone() + c.Transport = r.http3RoundTripper } - - done := make(chan struct{}) - defer close(done) - - type result struct { - resp *http.Response - err error + ctx, cancel := context.WithTimeout(context.Background(), failoverTimeout) + defer cancel() + resp, err := c.Do(req.WithContext(ctx)) + if err == nil { + return resp, err } - - respCh := make(chan result) - doRequest := func(client *http.Client) { - resp, err := client.Do(req) - select { - case respCh <- result{resp: resp, err: err}: - case <-done: - if client == &mainClient && err != nil { - r.uc.ReBootstrap() - } - if resp != nil { - defer resp.Body.Close() - _, _ = io.Copy(io.Discard, resp.Body) - } - } - } - - mainClientStarted := make(chan struct{}) - go func() { - // Notify failoverClient that mainClient started. - close(mainClientStarted) - doRequest(&mainClient) - }() - go func() { - // Wait mainClient started first. - <-mainClientStarted - doRequest(&failoverClient) - }() - - var ( - resp *http.Response - err error - ) - for range []*http.Client{&mainClient, &failoverClient} { - res := <-respCh - if res.err == nil { - resp = res.resp - break - } - err = res.err - } - return resp, err + r.uc.ReBootstrap() + c.Transport = r.uc.transport + return c.Do(req) } func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {