Use different failover mechanism on Linux

Instead of always doubling the request, first we wrap the request with a
failover timeout, 500ms, which is an average time for a normal request.
If this request failed, trigger re-bootstrapping and retry the request.
This commit is contained in:
Cuong Manh Le
2023-04-29 13:15:10 +07:00
committed by Cuong Manh Le
parent d57c1d6d44
commit 56d8dc865f

82
doh.go
View File

@@ -9,6 +9,7 @@ import (
"net/http"
"net/url"
"runtime"
"time"
"github.com/miekg/dns"
)
@@ -103,79 +104,28 @@ func (r *dohResolver) doRequest(req *http.Request) (*http.Response, error) {
return c.Do(req)
}
const failoverTimeout = 500 * time.Millisecond
// doRequestWithFailover is like doRequest, but wrap the request with initial timeout.
// If the first request failed, it's likely that the transport was broken, then trigger
// re-bootstrapping and retry the request.
func (r *dohResolver) doRequestWithFailover(req *http.Request) (*http.Response, error) {
// To deal with network changes, for example, connect/disconnect to VPN,
// We use two clients:
//
// - mainClient: use the current transport.
// - failoverClient: use a clone of the current transport.
//
// Two clients will perform the requests concurrently, but with mainClient
// started first. So in normal condition, mainClient is likely to return first,
// and we will use its result. In case of mainClient failed, we trigger the
// re-bootstrapping process, and use the result from failover client.
mainClient := http.Client{Transport: r.transport}
failoverClient := http.Client{}
c := http.Client{Transport: r.transport}
if r.isDoH3 {
if r.http3RoundTripper == nil {
return nil, errors.New("DoH3 is not supported")
}
// TODO: figure out how to clone DOH3 round tripper?
mainClient.Transport = r.http3RoundTripper
failoverClient.Transport = r.http3RoundTripper
} else {
failoverClient.Transport = r.transport.Clone()
c.Transport = r.http3RoundTripper
}
done := make(chan struct{})
defer close(done)
type result struct {
resp *http.Response
err error
ctx, cancel := context.WithTimeout(context.Background(), failoverTimeout)
defer cancel()
resp, err := c.Do(req.WithContext(ctx))
if err == nil {
return resp, err
}
respCh := make(chan result)
doRequest := func(client *http.Client) {
resp, err := client.Do(req)
select {
case respCh <- result{resp: resp, err: err}:
case <-done:
if client == &mainClient && err != nil {
r.uc.ReBootstrap()
}
if resp != nil {
defer resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body)
}
}
}
mainClientStarted := make(chan struct{})
go func() {
// Notify failoverClient that mainClient started.
close(mainClientStarted)
doRequest(&mainClient)
}()
go func() {
// Wait mainClient started first.
<-mainClientStarted
doRequest(&failoverClient)
}()
var (
resp *http.Response
err error
)
for range []*http.Client{&mainClient, &failoverClient} {
res := <-respCh
if res.err == nil {
resp = res.resp
break
}
err = res.err
}
return resp, err
r.uc.ReBootstrap()
c.Transport = r.uc.transport
return c.Do(req)
}
func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {