diff --git a/config.go b/config.go index 904eaf1..b30d8a4 100644 --- a/config.go +++ b/config.go @@ -110,6 +110,7 @@ type UpstreamConfig struct { transport *http.Transport `mapstructure:"-" toml:"-"` http3RoundTripper http.RoundTripper `mapstructure:"-" toml:"-"` certPool *x509.CertPool `mapstructure:"-" toml:"-"` + u *url.URL `mapstructure:"-" toml:"-"` g singleflight.Group bootstrapIPs []string @@ -142,6 +143,10 @@ type Rule map[string][]string func (uc *UpstreamConfig) Init() { if u, err := url.Parse(uc.Endpoint); err == nil { uc.Domain = u.Host + switch uc.Type { + case ResolverTypeDOH, ResolverTypeDOH3: + uc.u = u + } } if uc.Domain != "" { return diff --git a/config_internal_test.go b/config_internal_test.go index 0a457d3..a470cf8 100644 --- a/config_internal_test.go +++ b/config_internal_test.go @@ -1,7 +1,10 @@ package ctrld import ( + "net/url" "testing" + + "github.com/stretchr/testify/assert" ) func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { @@ -19,3 +22,139 @@ func TestUpstreamConfig_SetupBootstrapIP(t *testing.T) { } t.Log(uc) } + +func TestUpstreamConfig_Init(t *testing.T) { + u1, _ := url.Parse("https://example.com") + u2, _ := url.Parse("https://example.com?k=v") + tests := []struct { + name string + uc *UpstreamConfig + expected *UpstreamConfig + }{ + { + "doh+doh3", + &UpstreamConfig{ + Name: "doh", + Type: "doh", + Endpoint: "https://example.com", + BootstrapIP: "", + Domain: "", + Timeout: 0, + }, + &UpstreamConfig{ + Name: "doh", + Type: "doh", + Endpoint: "https://example.com", + BootstrapIP: "", + Domain: "example.com", + Timeout: 0, + u: u1, + }, + }, + { + "doh+doh3 with query param", + &UpstreamConfig{ + Name: "doh", + Type: "doh", + Endpoint: "https://example.com?k=v", + BootstrapIP: "", + Domain: "", + Timeout: 0, + }, + &UpstreamConfig{ + Name: "doh", + Type: "doh", + Endpoint: "https://example.com?k=v", + BootstrapIP: "", + Domain: "example.com", + Timeout: 0, + u: u2, + }, + }, + { + "dot+doq", + &UpstreamConfig{ + Name: "dot", + Type: "dot", + Endpoint: "freedns.controld.com:8853", + BootstrapIP: "", + Domain: "", + Timeout: 0, + }, + &UpstreamConfig{ + Name: "dot", + Type: "dot", + Endpoint: "freedns.controld.com:8853", + BootstrapIP: "", + Domain: "freedns.controld.com", + Timeout: 0, + }, + }, + { + "dot+doq without port", + &UpstreamConfig{ + Name: "dot", + Type: "dot", + Endpoint: "freedns.controld.com", + BootstrapIP: "", + Domain: "", + Timeout: 0, + }, + &UpstreamConfig{ + Name: "dot", + Type: "dot", + Endpoint: "freedns.controld.com:853", + BootstrapIP: "", + Domain: "freedns.controld.com", + Timeout: 0, + }, + }, + { + "legacy", + &UpstreamConfig{ + Name: "legacy", + Type: "legacy", + Endpoint: "1.2.3.4:53", + BootstrapIP: "", + Domain: "", + Timeout: 0, + }, + &UpstreamConfig{ + Name: "legacy", + Type: "legacy", + Endpoint: "1.2.3.4:53", + BootstrapIP: "1.2.3.4", + Domain: "1.2.3.4", + Timeout: 0, + }, + }, + { + "legacy without port", + &UpstreamConfig{ + Name: "legacy", + Type: "legacy", + Endpoint: "1.2.3.4", + BootstrapIP: "", + Domain: "", + Timeout: 0, + }, + &UpstreamConfig{ + Name: "legacy", + Type: "legacy", + Endpoint: "1.2.3.4:53", + BootstrapIP: "1.2.3.4", + Domain: "1.2.3.4", + Timeout: 0, + }, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + tc.uc.Init() + assert.Equal(t, tc.expected, tc.uc) + }) + } +} diff --git a/config_test.go b/config_test.go index f315c92..e2d75c0 100644 --- a/config_test.go +++ b/config_test.go @@ -165,116 +165,3 @@ func configWithInvalidRcodes(t *testing.T) *ctrld.Config { } return cfg } - -func TestUpstreamConfig_Init(t *testing.T) { - tests := []struct { - name string - uc *ctrld.UpstreamConfig - expected *ctrld.UpstreamConfig - }{ - { - "doh+doh3", - &ctrld.UpstreamConfig{ - Name: "doh", - Type: "doh", - Endpoint: "https://example.com", - BootstrapIP: "", - Domain: "", - Timeout: 0, - }, - &ctrld.UpstreamConfig{ - Name: "doh", - Type: "doh", - Endpoint: "https://example.com", - BootstrapIP: "", - Domain: "example.com", - Timeout: 0, - }, - }, - { - "dot+doq", - &ctrld.UpstreamConfig{ - Name: "dot", - Type: "dot", - Endpoint: "freedns.controld.com:8853", - BootstrapIP: "", - Domain: "", - Timeout: 0, - }, - &ctrld.UpstreamConfig{ - Name: "dot", - Type: "dot", - Endpoint: "freedns.controld.com:8853", - BootstrapIP: "", - Domain: "freedns.controld.com", - Timeout: 0, - }, - }, - { - "dot+doq without port", - &ctrld.UpstreamConfig{ - Name: "dot", - Type: "dot", - Endpoint: "freedns.controld.com", - BootstrapIP: "", - Domain: "", - Timeout: 0, - }, - &ctrld.UpstreamConfig{ - Name: "dot", - Type: "dot", - Endpoint: "freedns.controld.com:853", - BootstrapIP: "", - Domain: "freedns.controld.com", - Timeout: 0, - }, - }, - { - "legacy", - &ctrld.UpstreamConfig{ - Name: "legacy", - Type: "legacy", - Endpoint: "1.2.3.4:53", - BootstrapIP: "", - Domain: "", - Timeout: 0, - }, - &ctrld.UpstreamConfig{ - Name: "legacy", - Type: "legacy", - Endpoint: "1.2.3.4:53", - BootstrapIP: "1.2.3.4", - Domain: "1.2.3.4", - Timeout: 0, - }, - }, - { - "legacy without port", - &ctrld.UpstreamConfig{ - Name: "legacy", - Type: "legacy", - Endpoint: "1.2.3.4", - BootstrapIP: "", - Domain: "", - Timeout: 0, - }, - &ctrld.UpstreamConfig{ - Name: "legacy", - Type: "legacy", - Endpoint: "1.2.3.4:53", - BootstrapIP: "1.2.3.4", - Domain: "1.2.3.4", - Timeout: 0, - }, - }, - } - - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - tc.uc.Init() - assert.Equal(t, tc.expected, tc.uc) - }) - } -} diff --git a/doh.go b/doh.go index 433cf41..4fd4bd6 100644 --- a/doh.go +++ b/doh.go @@ -7,13 +7,14 @@ import ( "fmt" "io" "net/http" + "net/url" "github.com/miekg/dns" ) func newDohResolver(uc *UpstreamConfig) *dohResolver { r := &dohResolver{ - endpoint: uc.Endpoint, + endpoint: uc.u, isDoH3: uc.Type == ResolverTypeDOH3, transport: uc.transport, http3RoundTripper: uc.http3RoundTripper, @@ -22,7 +23,7 @@ func newDohResolver(uc *UpstreamConfig) *dohResolver { } type dohResolver struct { - endpoint string + endpoint *url.URL isDoH3 bool transport *http.Transport http3RoundTripper http.RoundTripper @@ -33,9 +34,14 @@ func (r *dohResolver) Resolve(ctx context.Context, msg *dns.Msg) (*dns.Msg, erro if err != nil { return nil, err } + enc := base64.RawURLEncoding.EncodeToString(data) - url := fmt.Sprintf("%s?dns=%s", r.endpoint, enc) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + query := r.endpoint.Query() + query.Add("dns", enc) + + endpoint := *r.endpoint + endpoint.RawQuery = query.Encode() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint.String(), nil) if err != nil { return nil, fmt.Errorf("could not create request: %w", err) }