all: use parallel dialer for connecting upstream/api

So we don't have to depend on network stack probing to decide whether
ipv4 or ipv6 will be used.

While at it, also prevent a race report when doing the same parallel
resolving for os resolver, even though this race is harmless.
This commit is contained in:
Cuong Manh Le
2023-04-24 19:56:01 +07:00
committed by Cuong Manh Le
parent d3d08022cc
commit d52cd11322
6 changed files with 127 additions and 122 deletions
+45 -7
View File
@@ -2,6 +2,7 @@ package net
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
@@ -37,7 +38,6 @@ var probeStackDialer = &net.Dialer{
var (
stackOnce atomic.Pointer[sync.Once]
ipv4Enabled bool
ipv6Enabled bool
canListenIPv6Local bool
hasNetworkUp bool
@@ -75,7 +75,6 @@ func probeStack() {
b.BackOff(context.Background(), err)
}
}
ipv4Enabled = supportIPv4()
ipv6Enabled = supportIPv6(context.Background())
canListenIPv6Local = supportListenIPv6Local()
}
@@ -85,11 +84,6 @@ func Up() bool {
return hasNetworkUp
}
func SupportsIPv4() bool {
stackOnce.Load().Do(probeStack)
return ipv4Enabled
}
func SupportsIPv6() bool {
stackOnce.Load().Do(probeStack)
return ipv6Enabled
@@ -112,3 +106,47 @@ func IsIPv6(ip string) bool {
parsedIP := net.ParseIP(ip)
return parsedIP != nil && parsedIP.To4() == nil && parsedIP.To16() != nil
}
type parallelDialerResult struct {
conn net.Conn
err error
}
type ParallelDialer struct {
net.Dialer
}
func (d *ParallelDialer) DialContext(ctx context.Context, network string, addrs []string) (net.Conn, error) {
if len(addrs) == 0 {
return nil, errors.New("empty addresses")
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()
ch := make(chan *parallelDialerResult, len(addrs))
var wg sync.WaitGroup
wg.Add(len(addrs))
go func() {
wg.Wait()
close(ch)
}()
for _, addr := range addrs {
go func(addr string) {
defer wg.Done()
conn, err := d.Dialer.DialContext(ctx, network, addr)
ch <- &parallelDialerResult{conn: conn, err: err}
}(addr)
}
errs := make([]error, 0, len(addrs))
for res := range ch {
if res.err == nil {
cancel()
return res.conn, res.err
}
errs = append(errs, res.err)
}
return nil, errors.Join(errs...)
}