Workaround for DOH broken transport when network changes

When network changes, for example: connect/disconnect VPN, the old
connection will become broken, but still can be re-used for new
requests. That would cause un-necessary delay for ctrld clients:

 - Time 0   - do request with broken transport, 5s timeout.
 - Time 0.5 - network stack become usable.
 - Time 5   - timeout reached.
 - Time 5.1 - do request with new transport -> success.

Instead, we can do two requests in parallel, with the failover one using
a fresh new transport. So if the main one is broken, we still can get
the result from the failover one.
This commit is contained in:
Cuong Manh Le
2023-04-29 03:19:56 +07:00
committed by Cuong Manh Le
parent 02fa7fbe2e
commit d57c1d6d44

101
doh.go
View File

@@ -8,6 +8,7 @@ import (
"io"
"net/http"
"net/url"
"runtime"
"github.com/miekg/dns"
)
@@ -27,6 +28,7 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver {
transport: uc.transport,
http3RoundTripper: uc.http3RoundTripper,
sendClientInfo: uc.UpstreamSendClientInfo(),
uc: uc,
}
return r
}
@@ -37,6 +39,7 @@ type dohResolver struct {
transport *http.Transport
http3RoundTripper http.RoundTripper
sendClientInfo bool
uc *UpstreamConfig
}
func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
@@ -57,14 +60,12 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
}
addHeader(ctx, req, r.sendClientInfo)
c := http.Client{Transport: r.transport}
if r.isDoH3 {
if r.http3RoundTripper == nil {
return nil, errors.New("DoH3 is not supported")
}
c.Transport = r.http3RoundTripper
var resp *http.Response
if runtime.GOOS == "linux" {
resp, err = r.doRequestWithFailover(req)
} else {
resp, err = r.doRequest(req)
}
resp, err := c.Do(req)
if err != nil {
if r.isDoH3 {
if closer, ok := r.http3RoundTripper.(io.Closer); ok {
@@ -91,6 +92,92 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro
return answer, nil
}
func (r *dohResolver) doRequest(req *http.Request) (*http.Response, error) {
c := http.Client{Transport: r.transport}
if r.isDoH3 {
if r.http3RoundTripper == nil {
return nil, errors.New("DoH3 is not supported")
}
c.Transport = r.http3RoundTripper
}
return c.Do(req)
}
func (r *dohResolver) doRequestWithFailover(req *http.Request) (*http.Response, error) {
// To deal with network changes, for example, connect/disconnect to VPN,
// We use two clients:
//
// - mainClient: use the current transport.
// - failoverClient: use a clone of the current transport.
//
// Two clients will perform the requests concurrently, but with mainClient
// started first. So in normal condition, mainClient is likely to return first,
// and we will use its result. In case of mainClient failed, we trigger the
// re-bootstrapping process, and use the result from failover client.
mainClient := http.Client{Transport: r.transport}
failoverClient := http.Client{}
if r.isDoH3 {
if r.http3RoundTripper == nil {
return nil, errors.New("DoH3 is not supported")
}
// TODO: figure out how to clone DOH3 round tripper?
mainClient.Transport = r.http3RoundTripper
failoverClient.Transport = r.http3RoundTripper
} else {
failoverClient.Transport = r.transport.Clone()
}
done := make(chan struct{})
defer close(done)
type result struct {
resp *http.Response
err error
}
respCh := make(chan result)
doRequest := func(client *http.Client) {
resp, err := client.Do(req)
select {
case respCh <- result{resp: resp, err: err}:
case <-done:
if client == &mainClient && err != nil {
r.uc.ReBootstrap()
}
if resp != nil {
defer resp.Body.Close()
_, _ = io.Copy(io.Discard, resp.Body)
}
}
}
mainClientStarted := make(chan struct{})
go func() {
// Notify failoverClient that mainClient started.
close(mainClientStarted)
doRequest(&mainClient)
}()
go func() {
// Wait mainClient started first.
<-mainClientStarted
doRequest(&failoverClient)
}()
var (
resp *http.Response
err error
)
for range []*http.Client{&mainClient, &failoverClient} {
res := <-respCh
if res.err == nil {
resp = res.resp
break
}
err = res.err
}
return resp, err
}
func addHeader(ctx context.Context, req *http.Request, sendClientInfo bool) {
req.Header.Set("Content-Type", headerApplicationDNS)
req.Header.Set("Accept", headerApplicationDNS)