mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-04-07 12:32:04 +02:00
Compare commits
608 Commits
issue-44
...
release-br
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
839b8236e7 | ||
|
|
3f59cdad1a | ||
|
|
c55e2a722c | ||
|
|
22a796f673 | ||
|
|
95dd871e2d | ||
|
|
5c0585b2e8 | ||
|
|
112d1cb5a9 | ||
|
|
bd9bb90dd4 | ||
|
|
82fc628bf3 | ||
|
|
2926c76b76 | ||
|
|
fe08f00746 | ||
|
|
9be15aeec8 | ||
|
|
9b2e51f53a | ||
|
|
e7040bd9f9 | ||
|
|
768cc81855 | ||
|
|
289a46dc2c | ||
|
|
1e8240bd1c | ||
|
|
12715e6f24 | ||
|
|
147106f2b9 | ||
|
|
a4f0418811 | ||
|
|
40c68a13a1 | ||
|
|
3f30ec30d8 | ||
|
|
4790eb2c88 | ||
|
|
da3ea05763 | ||
|
|
209c9211b9 | ||
|
|
acbebcf7c2 | ||
|
|
2e8a0f00a0 | ||
|
|
1f4c47318e | ||
|
|
e8d1a4604e | ||
|
|
8d63a755ba | ||
|
|
f05519d1c8 | ||
|
|
1804e6db67 | ||
|
|
d0341497d1 | ||
|
|
27c5be43c2 | ||
|
|
3beffd0dc8 | ||
|
|
1f9c586444 | ||
|
|
a92e1ca024 | ||
|
|
705df72110 | ||
|
|
22122c45b2 | ||
|
|
57a9bb9fab | ||
|
|
78ea2d6361 | ||
|
|
df3cf7ef62 | ||
|
|
80e652b8d9 | ||
|
|
091c7edb19 | ||
|
|
6c550b1d74 | ||
|
|
3ca559e5a4 | ||
|
|
0e3f764299 | ||
|
|
e52402eb0c | ||
|
|
2133f31854 | ||
|
|
a198a5cd65 | ||
|
|
eb2b231bd2 | ||
|
|
7af29cfbc0 | ||
|
|
ce1a165348 | ||
|
|
fd48e6d795 | ||
|
|
d71d1341b6 | ||
|
|
21855df4af | ||
|
|
66e2d3a40a | ||
|
|
26257cf24a | ||
|
|
36a7423634 | ||
|
|
e616091249 | ||
|
|
0948161529 | ||
|
|
ce29b5d217 | ||
|
|
de24fa293e | ||
|
|
6663925c4d | ||
|
|
b9ece6d7b9 | ||
|
|
c4efa1ab97 | ||
|
|
7cea5305e1 | ||
|
|
a20fbf95de | ||
|
|
628c4302aa | ||
|
|
8dc34f8bf5 | ||
|
|
b4faf82f76 | ||
|
|
a983dfaee2 | ||
|
|
62f73bcaa2 | ||
|
|
00e9d2bdd3 | ||
|
|
ace3b1e66e | ||
|
|
d1ea1ba08c | ||
|
|
c06c8aa859 | ||
|
|
0c2cc00c4f | ||
|
|
8d6ea91f35 | ||
|
|
7dfb77228f | ||
|
|
24910f1fa6 | ||
|
|
433a61d2ee | ||
|
|
3937e885f0 | ||
|
|
c651003cc4 | ||
|
|
b7ccfcb8b4 | ||
|
|
a9ed70200b | ||
|
|
c6365e6b74 | ||
|
|
dacc67e50f | ||
|
|
c60cf33af3 | ||
|
|
f27cbe3525 | ||
|
|
2de1b9929a | ||
|
|
8bf654aece | ||
|
|
84376ed719 | ||
|
|
7a136b8874 | ||
|
|
58c0e4f15a | ||
|
|
e0d35d8ba2 | ||
|
|
3b2e48761e | ||
|
|
b27064008e | ||
|
|
1ad63827e1 | ||
|
|
20e61550c2 | ||
|
|
020b814402 | ||
|
|
e578867118 | ||
|
|
46a1039f21 | ||
|
|
cc9e27de5f | ||
|
|
6ab3ab9faf | ||
|
|
e68bfa795a | ||
|
|
e60a92e93e | ||
|
|
62fe14f76b | ||
|
|
a0c5062e3a | ||
|
|
49eb152d02 | ||
|
|
b05056423a | ||
|
|
c7168739c7 | ||
|
|
5b1faf1ce3 | ||
|
|
513a6f9ec7 | ||
|
|
8db6fa4232 | ||
|
|
5036de2602 | ||
|
|
332f8ccc37 | ||
|
|
a582195cec | ||
|
|
9fe36ae984 | ||
|
|
54cb455522 | ||
|
|
8bd3b9e474 | ||
|
|
eff5ff580b | ||
|
|
c45f863ed8 | ||
|
|
414d4e356d | ||
|
|
ef697eb781 | ||
|
|
0631ffe831 | ||
|
|
7444d8517a | ||
|
|
3480043e40 | ||
|
|
619b6e7516 | ||
|
|
0123ca44fb | ||
|
|
7929aafe2a | ||
|
|
dc433f8dc9 | ||
|
|
8ccaeeab60 | ||
|
|
043a28eb33 | ||
|
|
c329402f5d | ||
|
|
23e6ad6e1f | ||
|
|
e6de78c1fa | ||
|
|
a670708f93 | ||
|
|
4ebe2fb5f4 | ||
|
|
3403b2039d | ||
|
|
e30ad31e0f | ||
|
|
81e0bad739 | ||
|
|
7d07d738dc | ||
|
|
0fae584e65 | ||
|
|
9e83085f2a | ||
|
|
41a00c68ac | ||
|
|
e3b99bf339 | ||
|
|
5007a87d3a | ||
|
|
60e65a37a6 | ||
|
|
d37d0e942c | ||
|
|
98042d8dbd | ||
|
|
af4b826b68 | ||
|
|
253a57ca01 | ||
|
|
caf98b4dfe | ||
|
|
398f71fd00 | ||
|
|
e1301ade96 | ||
|
|
7a23f82192 | ||
|
|
715bcc4aa1 | ||
|
|
0c74838740 | ||
|
|
4b05b6da7b | ||
|
|
375844ff1a | ||
|
|
1d207379cb | ||
|
|
fb49cb71e3 | ||
|
|
9618efbcde | ||
|
|
bb2210b06a | ||
|
|
917052723d | ||
|
|
fef85cadeb | ||
|
|
4a05fb6b28 | ||
|
|
6644ce53f2 | ||
|
|
72f0b89fdc | ||
|
|
41a97a6609 | ||
|
|
38064d6ad5 | ||
|
|
ae6945cedf | ||
|
|
3132d1b032 | ||
|
|
2716ae29bd | ||
|
|
1c50c2b6af | ||
|
|
cf6d16b439 | ||
|
|
60686f55ff | ||
|
|
47d7ace3a7 | ||
|
|
2d3779ec27 | ||
|
|
595071b608 | ||
|
|
57ef717080 | ||
|
|
eb27d1482b | ||
|
|
f57972ead7 | ||
|
|
168eaf538b | ||
|
|
1560455ca3 | ||
|
|
028475a193 | ||
|
|
f7a6dbe39b | ||
|
|
e573a490c9 | ||
|
|
ce3281e70d | ||
|
|
0fbfd160c9 | ||
|
|
20759017e6 | ||
|
|
69e0aab73e | ||
|
|
7ed6733fb7 | ||
|
|
9718ab8579 | ||
|
|
2687a4a018 | ||
|
|
2d9c60dea1 | ||
|
|
841be069b7 | ||
|
|
7833132917 | ||
|
|
e9e63b0983 | ||
|
|
4df470b869 | ||
|
|
89600f6091 | ||
|
|
f986a575e8 | ||
|
|
9c2fe8d21f | ||
|
|
8bcbb9249e | ||
|
|
a95d50c0af | ||
|
|
5db7d3577b | ||
|
|
c53a0ca1c4 | ||
|
|
6fd3d1788a | ||
|
|
087c1975e5 | ||
|
|
3713cbecc3 | ||
|
|
6046789fa4 | ||
|
|
3ea69b180c | ||
|
|
db6e977e3a | ||
|
|
a5c776c846 | ||
|
|
5a566c028a | ||
|
|
ff43c74d8d | ||
|
|
3c7255569c | ||
|
|
4a92ec4d2d | ||
|
|
9bbccb4082 | ||
|
|
4f62314646 | ||
|
|
cb49d0d947 | ||
|
|
89f7874fc6 | ||
|
|
221917e80b | ||
|
|
37d41bd215 | ||
|
|
8a96b8bec4 | ||
|
|
02ee113b95 | ||
|
|
f71dd78915 | ||
|
|
cd5619a05b | ||
|
|
a63a30c76b | ||
|
|
f5ba8be182 | ||
|
|
a9f76322bd | ||
|
|
ed39269c80 | ||
|
|
09426dcd36 | ||
|
|
17941882a9 | ||
|
|
70ab8032a0 | ||
|
|
8360bdc50a | ||
|
|
6837176ec7 | ||
|
|
5e9b4244e7 | ||
|
|
9b6a308958 | ||
|
|
71e327653a | ||
|
|
a56711796f | ||
|
|
09495f2a7c | ||
|
|
484643e114 | ||
|
|
da91aabc35 | ||
|
|
c654398981 | ||
|
|
47a90ec2a1 | ||
|
|
2875e22d0b | ||
|
|
c5d14e0075 | ||
|
|
84e06c363c | ||
|
|
5b9ccc5065 | ||
|
|
6ca1a7ccc7 | ||
|
|
9d666be5d4 | ||
|
|
65de7edcde | ||
|
|
0cdff0d368 | ||
|
|
f87220a908 | ||
|
|
30ea0c6499 | ||
|
|
9501e35c60 | ||
|
|
5ac9d17bdf | ||
|
|
cb14992ddc | ||
|
|
e88372fc8c | ||
|
|
b320662d67 | ||
|
|
ce353cd4d9 | ||
|
|
4befd33866 | ||
|
|
4b36e3ac44 | ||
|
|
f507bc8f9e | ||
|
|
14c88f4a6d | ||
|
|
3e388c2857 | ||
|
|
cfe1209d61 | ||
|
|
5a88a7c22c | ||
|
|
8c661c4401 | ||
|
|
e6f256d640 | ||
|
|
ede354166b | ||
|
|
282a8ce78e | ||
|
|
08fe04f1ee | ||
|
|
082d14a9ba | ||
|
|
617674ce43 | ||
|
|
7088df58dd | ||
|
|
9cbd9b3e44 | ||
|
|
e6586fd360 | ||
|
|
33a6db2599 | ||
|
|
70b0c4f7b9 | ||
|
|
5af3ec4f7b | ||
|
|
79476add12 | ||
|
|
1634a06330 | ||
|
|
a007394f60 | ||
|
|
62a0ba8731 | ||
|
|
e8d3ed1acd | ||
|
|
8b98faa441 | ||
|
|
30320ec9c7 | ||
|
|
5f4a399850 | ||
|
|
82e0d4b0c4 | ||
|
|
95a9df826d | ||
|
|
3b71d26cf3 | ||
|
|
c233ad9b1b | ||
|
|
12d6484b1c | ||
|
|
bc7b1cc6d8 | ||
|
|
ec684348ed | ||
|
|
18a19a3aa2 | ||
|
|
905f2d08c5 | ||
|
|
04947b4d87 | ||
|
|
72bf80533e | ||
|
|
9ddedf926e | ||
|
|
139dd62ff3 | ||
|
|
50ef00526e | ||
|
|
80cf79b9cb | ||
|
|
e6ad39b070 | ||
|
|
56f9c72569 | ||
|
|
dc48c908b8 | ||
|
|
9b0f0e792a | ||
|
|
b3eebb19b6 | ||
|
|
c24589a5be | ||
|
|
1e1c5a4dc8 | ||
|
|
339023421a | ||
|
|
a00d2a431a | ||
|
|
5aca118dbb | ||
|
|
411f7434f4 | ||
|
|
34801382f5 | ||
|
|
b9f2259ae4 | ||
|
|
19020a96bf | ||
|
|
96085147ff | ||
|
|
f3dd344026 | ||
|
|
486096416f | ||
|
|
5710f2e984 | ||
|
|
09936f1f07 | ||
|
|
0d6ca57536 | ||
|
|
3ddcb84db8 | ||
|
|
1012bf063f | ||
|
|
b8155e6182 | ||
|
|
9a34df61bb | ||
|
|
fbb879edf9 | ||
|
|
ac97c88876 | ||
|
|
a1fda2c0de | ||
|
|
f499770d45 | ||
|
|
4769da4ef4 | ||
|
|
c2556a8e39 | ||
|
|
29bf329f6a | ||
|
|
1dee4305bc | ||
|
|
429a98b690 | ||
|
|
da01a146d2 | ||
|
|
dd9f2465be | ||
|
|
b5cf0e2b31 | ||
|
|
1db159ad34 | ||
|
|
6604f973ac | ||
|
|
69ee6582e2 | ||
|
|
6f12667e8c | ||
|
|
b002dff624 | ||
|
|
affef963c1 | ||
|
|
56b2056190 | ||
|
|
c1e6f5126a | ||
|
|
1a8c1ec73d | ||
|
|
52954b8ceb | ||
|
|
a5025e35ea | ||
|
|
07f80c9ebf | ||
|
|
13db23553d | ||
|
|
3963fce43b | ||
|
|
ea4e5147bd | ||
|
|
7a491a4cc5 | ||
|
|
5ba90748f6 | ||
|
|
20f8f22bae | ||
|
|
b50cccac85 | ||
|
|
34ebe9b054 | ||
|
|
43d82cf1a7 | ||
|
|
ab88174091 | ||
|
|
ebcbf85373 | ||
|
|
87513cba6d | ||
|
|
64bcd2f00d | ||
|
|
cc6ae290f8 | ||
|
|
3e62bd3dbd | ||
|
|
8491f9c455 | ||
|
|
3ca754b438 | ||
|
|
8c7c3901e8 | ||
|
|
a9672dfff5 | ||
|
|
203a2ec8b8 | ||
|
|
810cbd1f4f | ||
|
|
49eebcdcbc | ||
|
|
e89021ec3a | ||
|
|
73a697b2fa | ||
|
|
9319d08046 | ||
|
|
7dc5138e91 | ||
|
|
8f189c919a | ||
|
|
906479a15c | ||
|
|
dabbf2037b | ||
|
|
b496147ce7 | ||
|
|
583718f234 | ||
|
|
fdb82f6ec3 | ||
|
|
5145729ab1 | ||
|
|
4d810261a4 | ||
|
|
18e8616834 | ||
|
|
d55563cac5 | ||
|
|
bb481d9bcc | ||
|
|
a163be3584 | ||
|
|
891b7cb2c6 | ||
|
|
176c22f229 | ||
|
|
faa0ed06b6 | ||
|
|
9515db7faf | ||
|
|
d822bf4257 | ||
|
|
0826671809 | ||
|
|
67d74774a9 | ||
|
|
5d65416227 | ||
|
|
49441f62f3 | ||
|
|
99651f6e5b | ||
|
|
edca1f4f89 | ||
|
|
3d834f00f6 | ||
|
|
6bb9e7a766 | ||
|
|
61fb71b1fa | ||
|
|
f8967c376f | ||
|
|
6d3c86c0be | ||
|
|
e42554f892 | ||
|
|
28984090e5 | ||
|
|
251255c746 | ||
|
|
32709dc64c | ||
|
|
71f26a6d81 | ||
|
|
44352f8006 | ||
|
|
af38623590 | ||
|
|
9c1665a759 | ||
|
|
eaad24e5e5 | ||
|
|
cfaf32f71a | ||
|
|
51b235b61a | ||
|
|
0a6d9d4454 | ||
|
|
dc700bbd52 | ||
|
|
cb445825f4 | ||
|
|
4d996e317b | ||
|
|
30c9012004 | ||
|
|
2a23feaf4b | ||
|
|
b82ad3720c | ||
|
|
8d2cb6091e | ||
|
|
3023f33dff | ||
|
|
22e97e981a | ||
|
|
44484e1231 | ||
|
|
eac60b87c7 | ||
|
|
8db28cb76e | ||
|
|
8dbe828b99 | ||
|
|
5c24acd952 | ||
|
|
998b9a5c5d | ||
|
|
0084e9ef26 | ||
|
|
122600bff2 | ||
|
|
41846b6d4c | ||
|
|
dfbcb1489d | ||
|
|
684019c2e3 | ||
|
|
e92619620d | ||
|
|
cebfd12d5c | ||
|
|
874ff01ab8 | ||
|
|
0bb8703f78 | ||
|
|
0bb51aa71d | ||
|
|
af2c1c87e0 | ||
|
|
8939debbc0 | ||
|
|
7591a0ccc6 | ||
|
|
c3ff8182af | ||
|
|
5897c174d3 | ||
|
|
f9a3f4c045 | ||
|
|
a2cb895cdc | ||
|
|
2bebe93e47 | ||
|
|
28ec1869fc | ||
|
|
17f6d7a77b | ||
|
|
9e6e647ff8 | ||
|
|
a2116e5eb5 | ||
|
|
564c9ef712 | ||
|
|
856abb71b7 | ||
|
|
0a30fdea69 | ||
|
|
4f125cf107 | ||
|
|
494d8be777 | ||
|
|
cd9c750884 | ||
|
|
91d319804b | ||
|
|
180eae60f2 | ||
|
|
d01f5c2777 | ||
|
|
294a90a807 | ||
|
|
c3b4ae9c79 | ||
|
|
09188bedf7 | ||
|
|
4614b98e94 | ||
|
|
990bc620f7 | ||
|
|
efb5a92571 | ||
|
|
8e0a96a44c | ||
|
|
43ff2f648c | ||
|
|
4816a09e3a | ||
|
|
3fea92c8b1 | ||
|
|
63f959c951 | ||
|
|
44ba6aadd9 | ||
|
|
d88cf52b4e | ||
|
|
58a00ea24a | ||
|
|
712b23a4bb | ||
|
|
baf836557c | ||
|
|
904b23eeac | ||
|
|
6aafe445f5 | ||
|
|
ebd516855b | ||
|
|
df4e04719e | ||
|
|
2440d922c6 | ||
|
|
f1b8d1c4ad | ||
|
|
79076bda35 | ||
|
|
9d2ea15346 | ||
|
|
77c1113ff7 | ||
|
|
e03ad4cd77 | ||
|
|
6e28517454 | ||
|
|
8ddbf881b3 | ||
|
|
c58516cfb0 | ||
|
|
34758f6205 | ||
|
|
a9959a6f3d | ||
|
|
511c4e696f | ||
|
|
bed7435b0c | ||
|
|
507c1afd59 | ||
|
|
2765487f10 | ||
|
|
80a88811cd | ||
|
|
823195c504 | ||
|
|
0f3e8c7ada | ||
|
|
ee5eb4fc4e | ||
|
|
d58d8074f4 | ||
|
|
94a0530991 | ||
|
|
073af0f89c | ||
|
|
6028b8f186 | ||
|
|
126477ef88 | ||
|
|
13391fd469 | ||
|
|
82e44b01af | ||
|
|
e355fd70ab | ||
|
|
d5c171735e | ||
|
|
b175368794 | ||
|
|
bcf4c25ba8 | ||
|
|
11b09af76d | ||
|
|
af0380a96a | ||
|
|
f39512b4c0 | ||
|
|
7ce62ccaec | ||
|
|
44c0a06996 | ||
|
|
f7d3db06c6 | ||
|
|
0ca37dc707 | ||
|
|
2bcba7b578 | ||
|
|
829e93c079 | ||
|
|
4896563e3c | ||
|
|
0c096d5f07 | ||
|
|
ab8f072388 | ||
|
|
32219e7d32 | ||
|
|
d292e03d1b | ||
|
|
5dd6336953 | ||
|
|
854a244ebb | ||
|
|
125b4b6077 | ||
|
|
46e8d4fad7 | ||
|
|
e5389ffecb | ||
|
|
46509be8a0 | ||
|
|
d3d2ed539f | ||
|
|
8496adc638 | ||
|
|
e1d078a2c3 | ||
|
|
0dee7518c4 | ||
|
|
774f07dd7f | ||
|
|
c271896551 | ||
|
|
82d887f52d | ||
|
|
6e27f877ff | ||
|
|
39a2cab051 | ||
|
|
72d2f4e7e3 | ||
|
|
19bc44a7f3 | ||
|
|
59dc74ffbb | ||
|
|
12c8ab696f | ||
|
|
28f32bd7e5 | ||
|
|
6b43639be5 | ||
|
|
6be80e4827 | ||
|
|
437fb1b16d | ||
|
|
61b6431b6e | ||
|
|
7ccecdd9f7 | ||
|
|
e43b2b5530 | ||
|
|
2cd8b7e021 | ||
|
|
d6768c4c39 | ||
|
|
59a895bfe2 | ||
|
|
cacd957594 | ||
|
|
2cd063ebd6 | ||
|
|
9ed8e49a08 | ||
|
|
66cb7cc21d | ||
|
|
4bf09120ff | ||
|
|
be0769e433 | ||
|
|
7b476e38be | ||
|
|
0a7d3445f4 | ||
|
|
76d2e2c226 | ||
|
|
3007cb86ec | ||
|
|
fa3af372ab | ||
|
|
48a780fc3e | ||
|
|
28df551195 | ||
|
|
e65a71b2ae | ||
|
|
dc61fd2554 | ||
|
|
a4edf266f0 | ||
|
|
7af59ee589 | ||
|
|
3f3c1d6d78 | ||
|
|
ab1d7fd796 | ||
|
|
6c2996a921 | ||
|
|
de32dd8ba4 | ||
|
|
d43e50ee2d | ||
|
|
aec2596262 | ||
|
|
78a7c87ecc | ||
|
|
1d3f8757bc | ||
|
|
c0c69d0739 | ||
|
|
1aa991298a | ||
|
|
f3a3227f21 | ||
|
|
a4c1983657 | ||
|
|
cc28b92935 | ||
|
|
eaa907a647 | ||
|
|
de951fd895 | ||
|
|
3f211d3cc2 | ||
|
|
2f46d512c6 | ||
|
|
12148ec231 | ||
|
|
9fe6af684f | ||
|
|
472bb05e95 | ||
|
|
50bfed706d | ||
|
|
350d8355b1 | ||
|
|
03781d4cec | ||
|
|
67e4afc06e | ||
|
|
32482809b7 | ||
|
|
c315d21be9 | ||
|
|
48b2031269 | ||
|
|
41139b3343 | ||
|
|
d5e6c7b13f | ||
|
|
60d6734e1f | ||
|
|
e684c7d8c4 | ||
|
|
ce35383341 |
2
.dockerignore
Normal file
2
.dockerignore
Normal file
@@ -0,0 +1,2 @@
|
||||
Dockerfile
|
||||
.git/
|
||||
8
.github/workflows/ci.yml
vendored
8
.github/workflows/ci.yml
vendored
@@ -9,18 +9,18 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ["windows-latest", "ubuntu-latest", "macOS-latest"]
|
||||
go: ["1.20.x"]
|
||||
go: ["1.24.x"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- uses: WillAbides/setup-go-faster@v1.8.0
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- run: "go test -race ./..."
|
||||
- uses: dominikh/staticcheck-action@v1.2.0
|
||||
- uses: dominikh/staticcheck-action@v1.4.0
|
||||
with:
|
||||
version: "2023.1.2"
|
||||
version: "2025.1.1"
|
||||
install-go: false
|
||||
cache-key: ${{ matrix.go }}
|
||||
|
||||
11
.gitignore
vendored
11
.gitignore
vendored
@@ -3,3 +3,14 @@ gon.hcl
|
||||
|
||||
/Build
|
||||
.DS_Store
|
||||
|
||||
# Release folder
|
||||
dist/
|
||||
|
||||
# Binaries
|
||||
ctrld-*
|
||||
|
||||
# generated file
|
||||
cmd/cli/rsrc_*.syso
|
||||
ctrld
|
||||
ctrld.exe
|
||||
|
||||
309
README.md
309
README.md
@@ -4,12 +4,16 @@
|
||||
[](https://pkg.go.dev/github.com/Control-D-Inc/ctrld)
|
||||
[](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld)
|
||||
|
||||

|
||||
|
||||
A highly configurable DNS forwarding proxy with support for:
|
||||
- Multiple listeners for incoming queries
|
||||
- Multiple upstreams with fallbacks
|
||||
- Multiple network policy driven DNS query steering
|
||||
- Multiple network policy driven DNS query steering (via network cidr, MAC address or FQDN)
|
||||
- Policy driven domain based "split horizon" DNS with wildcard support
|
||||
- Integrations with common router vendors and firmware
|
||||
- LAN client discovery via DHCP, mDNS, ARP, NDP, hosts file parsing
|
||||
- Prometheus metrics exporter
|
||||
|
||||
## TLDR
|
||||
Proxy legacy DNS traffic to secure DNS upstreams in highly configurable ways.
|
||||
@@ -31,13 +35,29 @@ All DNS protocols are supported, including:
|
||||
|
||||
## OS Support
|
||||
- Windows (386, amd64, arm)
|
||||
- Mac (amd64, arm64)
|
||||
- Windows Server (386, amd64)
|
||||
- MacOS (amd64, arm64)
|
||||
- Linux (386, amd64, arm, mips)
|
||||
- FreeBSD
|
||||
- Common routers (See Router Mode below)
|
||||
- FreeBSD (386, amd64, arm)
|
||||
- Common routers (See below)
|
||||
|
||||
|
||||
### Supported Routers
|
||||
You can run `ctrld` on any supported router. The list of supported routers and firmware includes:
|
||||
- Asus Merlin
|
||||
- DD-WRT
|
||||
- Firewalla
|
||||
- FreshTomato
|
||||
- GL.iNet
|
||||
- OpenWRT
|
||||
- pfSense / OPNsense
|
||||
- Synology
|
||||
- Ubiquiti (UniFi, EdgeOS)
|
||||
|
||||
`ctrld` will attempt to interface with dnsmasq (or Windows Server) whenever possible and set itself as the upstream, while running on port 5354. On FreeBSD based OSes, `ctrld` will terminate dnsmasq and unbound in order to be able to listen on port 53 directly.
|
||||
|
||||
# Install
|
||||
There are several ways to download and install `ctrld.
|
||||
There are several ways to download and install `ctrld`.
|
||||
|
||||
## Quick Install
|
||||
The simplest way to download and install `ctrld` is to use the following installer command on any UNIX-like platform:
|
||||
@@ -46,30 +66,41 @@ The simplest way to download and install `ctrld` is to use the following install
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl)"'
|
||||
```
|
||||
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative cmd:
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative PowerShell:
|
||||
```shell
|
||||
powershell -Command "(Invoke-WebRequest -Uri 'https://api.controld.com/dl' -UseBasicParsing).Content | Set-Content 'ctrld_install.bat'" && ctrld_install.bat
|
||||
(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'"
|
||||
```
|
||||
|
||||
Or you can pull and run a Docker container from [Docker Hub](https://hub.docker.com/r/controldns/ctrld)
|
||||
```shell
|
||||
docker run -d --name=ctrld -p 127.0.0.1:53:53/tcp -p 127.0.0.1:53:53/udp controldns/ctrld:latest
|
||||
```
|
||||
|
||||
## Download Manually
|
||||
Alternatively, if you know what you're doing you can download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section for the appropriate platform.
|
||||
|
||||
## Build
|
||||
Lastly, you can build `ctrld` from source which requires `go1.19+`:
|
||||
Lastly, you can build `ctrld` from source which requires `go1.21+`:
|
||||
|
||||
```shell
|
||||
$ go build ./cmd/ctrld
|
||||
go build ./cmd/ctrld
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
$ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
```
|
||||
|
||||
|
||||
# Usage
|
||||
The cli is self documenting, so free free to run `--help` on any sub-command to get specific usages.
|
||||
The cli is self documenting, so feel free to run `--help` on any sub-command to get specific usages.
|
||||
|
||||
## Arguments
|
||||
```
|
||||
@@ -85,19 +116,16 @@ Usage:
|
||||
|
||||
Available Commands:
|
||||
run Run the DNS proxy server
|
||||
service Manage ctrld service
|
||||
start Quick start service and configure DNS on interface
|
||||
stop Quick stop service and remove DNS from interface
|
||||
setup Auto-setup Control D on a router.
|
||||
|
||||
Supported platforms:
|
||||
|
||||
ₒ ddwrt
|
||||
ₒ merlin
|
||||
ₒ openwrt
|
||||
ₒ ubios
|
||||
ₒ auto - detect the platform you are running on
|
||||
|
||||
restart Restart the ctrld service
|
||||
reload Reload the ctrld service
|
||||
status Show status of the ctrld service
|
||||
uninstall Stop and uninstall the ctrld service
|
||||
service Manage ctrld service
|
||||
clients Manage clients
|
||||
upgrade Upgrading ctrld to latest version
|
||||
log Manage runtime debug logs
|
||||
|
||||
Flags:
|
||||
-h, --help help for ctrld
|
||||
@@ -109,108 +137,99 @@ Use "ctrld [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
## Basic Run Mode
|
||||
To start the server with default configuration, simply run: `./ctrld run`. This will create a generic `ctrld.toml` file in the **working directory** and start the application in foreground.
|
||||
1. Start the server
|
||||
```
|
||||
$ sudo ./ctrld run
|
||||
This is the most basic way to run `ctrld`, in foreground mode. Unless you already have a config file, a default one will be generated.
|
||||
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe run
|
||||
```
|
||||
|
||||
2. Run a test query using a DNS client, for example, `dig`:
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld run
|
||||
```
|
||||
|
||||
You can then run a test query using a DNS client, for example, `dig`:
|
||||
```
|
||||
$ dig verify.controld.com @127.0.0.1 +short
|
||||
api.controld.com.
|
||||
147.185.34.1
|
||||
```
|
||||
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file and go nuts with it. To enforce a new config, restart the server.
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file that was generated. To enforce a new config, restart the server.
|
||||
|
||||
## Service Mode
|
||||
To run the application in service mode on any Windows, MacOS or Linux distibution, simply run: `./ctrld start` as system/root user. This will create a generic `ctrld.toml` file in the **user home** directory (on Windows) or `/etc/controld/` (everywhere else), start the system service, and configure the listener on the default network interface. Service will start on OS boot.
|
||||
This mode will run the application as a background system service on any Windows, MacOS, Linux, FreeBSD distribution or supported router. This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot.
|
||||
|
||||
In order to stop the service, and restore your DNS to original state, simply run `./ctrld stop`. If you wish to uninstall the service permanently, run `./ctrld service uninstall`.
|
||||
When Control D upstreams are used on a router type device, `ctrld` will [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
|
||||
For granular control of the service, run the `service` command. Each sub-command has its own help section so you can see what arguments you can supply.
|
||||
### Command
|
||||
|
||||
```
|
||||
Manage ctrld service
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start
|
||||
```
|
||||
|
||||
Usage:
|
||||
ctrld service [command]
|
||||
Linux or Macos
|
||||
```
|
||||
sudo ctrld start
|
||||
```
|
||||
|
||||
Available Commands:
|
||||
interfaces Manage network interfaces
|
||||
restart Restart the ctrld service
|
||||
start Start the ctrld service
|
||||
status Show status of the ctrld service
|
||||
stop Stop the ctrld service
|
||||
uninstall Uninstall the ctrld service
|
||||
If `ctrld` is not in your system path (you installed it manually), you will need to run the above commands from the directory where you installed `ctrld`.
|
||||
|
||||
Flags:
|
||||
-h, --help help for service
|
||||
In order to stop the service, and restore your DNS to original state, simply run `ctrld stop`. If you wish to stop and uninstall the service permanently, run `ctrld uninstall`.
|
||||
|
||||
Global Flags:
|
||||
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug level logging
|
||||
## Unmanaged Service Mode
|
||||
This mode functions similarly to the "Service Mode" above except it will simply start a system service and the config defined listeners, but **will not make any changes to any network interfaces**. You can then set the `ctrld` listener(s) IP on the desired network interfaces manually.
|
||||
|
||||
Use "ctrld service [command] --help" for more information about a command.
|
||||
```
|
||||
### Command
|
||||
|
||||
## Router Mode
|
||||
You can run `ctrld` on any supported router, which will function similarly to the Service Mode mentioned above. The list of supported routers and firmware includes:
|
||||
- Asus Merlin
|
||||
- DD-WRT
|
||||
- FreshTomato
|
||||
- GL.iNet
|
||||
- OpenWRT
|
||||
- pfSense
|
||||
- Synology
|
||||
- Ubiquiti (UniFi, EdgeOS)
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe service start
|
||||
```
|
||||
|
||||
In order to start `ctrld` as a DNS provider, simply run `./ctrld setup auto` command.
|
||||
|
||||
In this mode, and when Control D upstreams are used, the router will [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
|
||||
### Control D Auto Configuration
|
||||
Application can be started with a specific resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `run` (foreground) or `start` (service) or `setup` (router) modes.
|
||||
|
||||
The following command will start the application in foreground mode, using the free "p2" resolver, which blocks Ads & Trackers.
|
||||
|
||||
```shell
|
||||
./ctrld run --cd p2
|
||||
```
|
||||
|
||||
Alternatively, you can use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Device.
|
||||
|
||||
```shell
|
||||
./ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
You can do the same while starting in router mode:
|
||||
```shell
|
||||
./ctrld setup auto --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above commands (in service or router modes only), the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `service uninstall` sub-commands
|
||||
- Your default network interface will be updated to use the listener started by the service
|
||||
- All OS DNS queries will be sent to the listener
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld service start
|
||||
```
|
||||
|
||||
# Configuration
|
||||
See [Configuration Docs](docs/config.md).
|
||||
`ctrld` can be configured in variety of different ways, which include: API, local config file or via cli launch args.
|
||||
|
||||
## Example
|
||||
- Start `listener.0` on 127.0.0.1:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `upstream.0` via DoH protocol
|
||||
## API Based Auto Configuration
|
||||
Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. In this mode, the application will automatically choose a non-conflicting IP and/or port and configure itself as the upstream to whatever process is running on port 53 (like dnsmasq or Windows DNS Server). This mode is used when the 1 liner installer command from the Control D onboarding guide is executed.
|
||||
|
||||
### Default Config
|
||||
The following command will use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Endpoint.
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --cd abcd1234
|
||||
```
|
||||
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above command, the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- All physical network interface will be updated to use the listener started by the service or dnsmasq upstream will be switched to `ctrld`
|
||||
- All DNS queries will be sent to the listener
|
||||
|
||||
## Manual Configuration
|
||||
`ctrld` is entirely config driven and can be configured in many different ways, please see [Configuration Docs](docs/config.md).
|
||||
|
||||
### Example
|
||||
```toml
|
||||
[listener]
|
||||
|
||||
[listener.0]
|
||||
ip = "127.0.0.1"
|
||||
ip = '0.0.0.0'
|
||||
port = 53
|
||||
restricted = false
|
||||
|
||||
[network]
|
||||
|
||||
@@ -218,10 +237,6 @@ See [Configuration Docs](docs/config.md).
|
||||
cidrs = ["0.0.0.0/0"]
|
||||
name = "Network 0"
|
||||
|
||||
[service]
|
||||
log_level = "info"
|
||||
log_path = ""
|
||||
|
||||
[upstream]
|
||||
|
||||
[upstream.0]
|
||||
@@ -230,26 +245,88 @@ See [Configuration Docs](docs/config.md).
|
||||
name = "Control D - Anti-Malware"
|
||||
timeout = 5000
|
||||
type = "doh"
|
||||
|
||||
[upstream.1]
|
||||
bootstrap_ip = "76.76.2.11"
|
||||
endpoint = "p2.freedns.controld.com"
|
||||
name = "Control D - No Ads"
|
||||
timeout = 3000
|
||||
type = "doq"
|
||||
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file.
|
||||
The above basic config will:
|
||||
- Start listener on 0.0.0.0:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `https://freedns.controld.com/p1` using DoH protocol
|
||||
|
||||
You can also supply configuration via launch argeuments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
## CLI Args
|
||||
If you're unable to use a config file, `ctrld` can be be supplied with basic configuration via launch arguments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
|
||||
### Example
|
||||
```
|
||||
ctrld run --listen=127.0.0.1:53 --primary_upstream=https://freedns.controld.com/p2 --secondary_upstream=10.0.10.1:53 --domains=*.company.int,very-secure.local --log /path/to/log.log
|
||||
```
|
||||
|
||||
The above will start a foreground process and:
|
||||
- Listen on `127.0.0.1:53` for DNS queries
|
||||
- Forward all queries to `https://freedns.controld.com/p2` using DoH protocol, while...
|
||||
- Excluding `*.company.int` and `very-secure.local` matching queries, that are forwarded to `10.0.10.1:53`
|
||||
- Write a debug log to `/path/to/log.log`
|
||||
|
||||
## DNS Intercept Mode
|
||||
When running `ctrld` alongside VPN software, DNS conflicts can cause intermittent failures, bypassed filtering, or configuration loops. DNS Intercept Mode prevents these issues by transparently capturing all DNS traffic on the system and routing it through `ctrld`, without modifying network adapter DNS settings.
|
||||
|
||||
### When to Use
|
||||
Enable DNS Intercept Mode if you:
|
||||
- Use corporate VPN software (F5, Cisco AnyConnect, Palo Alto GlobalProtect, Zscaler)
|
||||
- Run overlay networks like Tailscale or WireGuard
|
||||
- Experience random DNS failures when VPN connects/disconnects
|
||||
- See gaps in your Control D analytics when VPN is active
|
||||
- Have endpoint security software that also manages DNS
|
||||
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --intercept-mode dns --cd RESOLVER_ID_HERE
|
||||
```
|
||||
|
||||
macOS
|
||||
```shell
|
||||
sudo ctrld start --intercept-mode dns --cd RESOLVER_ID_HERE
|
||||
```
|
||||
|
||||
`--intercept-mode dns` automatically detects VPN internal domains and routes them to the VPN's DNS server, while Control D handles everything else.
|
||||
|
||||
To disable intercept mode on a service that already has it enabled:
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --intercept-mode off
|
||||
```
|
||||
|
||||
macOS
|
||||
```shell
|
||||
sudo ctrld start --intercept-mode off
|
||||
```
|
||||
|
||||
This removes the intercept rules and reverts to standard interface-based DNS configuration.
|
||||
|
||||
### Platform Support
|
||||
| Platform | Supported | Mechanism |
|
||||
|----------|-----------|-----------|
|
||||
| Windows | ✅ | NRPT (Name Resolution Policy Table) |
|
||||
| macOS | ✅ | pf (packet filter) redirect |
|
||||
| Linux | ❌ | Not currently supported |
|
||||
|
||||
### Features
|
||||
- **VPN split routing** — VPN-specific domains are automatically detected and forwarded to the VPN's DNS server
|
||||
- **Captive portal recovery** — Wi-Fi login pages (hotels, airports, coffee shops) work automatically
|
||||
- **No network adapter changes** — DNS settings stay untouched, eliminating conflicts entirely
|
||||
- **Automatic port 53 conflict resolution** — if another process (e.g., `mDNSResponder` on macOS) is already using port 53, `ctrld` automatically listens on a different port. OS-level packet interception redirects all DNS traffic to `ctrld` transparently, so no manual configuration is needed. This only applies to intercept mode.
|
||||
|
||||
### Tested VPN Software
|
||||
- F5 BIG-IP APM
|
||||
- Cisco AnyConnect
|
||||
- Palo Alto GlobalProtect
|
||||
- Tailscale (including Exit Nodes)
|
||||
- Windscribe
|
||||
- WireGuard
|
||||
|
||||
For more details, see the [DNS Intercept Mode documentation](https://docs.controld.com/docs/dns-intercept).
|
||||
|
||||
## Contributing
|
||||
See [Contribution Guideline](./docs/contributing.md)
|
||||
|
||||
## Roadmap
|
||||
The following functionality is on the roadmap and will be available in future releases.
|
||||
- Prometheus metrics exporter
|
||||
- DNS intercept mode
|
||||
- Support for more routers (let us know which ones)
|
||||
|
||||
@@ -5,7 +5,18 @@ type ClientInfoCtxKey struct{}
|
||||
|
||||
// ClientInfo represents ctrld's clients information.
|
||||
type ClientInfo struct {
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Self bool
|
||||
ClientIDPref string
|
||||
}
|
||||
|
||||
// LeaseFileFormat specifies the format of DHCP lease file.
|
||||
type LeaseFileFormat string
|
||||
|
||||
const (
|
||||
Dnsmasq LeaseFileFormat = "dnsmasq"
|
||||
IscDhcpd LeaseFileFormat = "isc-dhcpd"
|
||||
KeaDHCP4 LeaseFileFormat = "kea-dhcp4"
|
||||
)
|
||||
|
||||
4
client_info_darwin.go
Normal file
4
client_info_darwin.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package ctrld
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return true }
|
||||
6
client_info_others.go
Normal file
6
client_info_others.go
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package ctrld
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool { return false }
|
||||
18
client_info_windows.go
Normal file
18
client_info_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package ctrld
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// isWindowsWorkStation reports whether ctrld was run on a Windows workstation machine.
|
||||
func isWindowsWorkStation() bool {
|
||||
// From https://learn.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-osversioninfoexa
|
||||
const VER_NT_WORKSTATION = 0x0000001
|
||||
osvi := windows.RtlGetVersion()
|
||||
return osvi.ProductType == VER_NT_WORKSTATION
|
||||
}
|
||||
|
||||
// SelfDiscover reports whether ctrld should only do self discover.
|
||||
func SelfDiscover() bool {
|
||||
return isWindowsWorkStation()
|
||||
}
|
||||
15
cmd/cli/ad_others.go
Normal file
15
cmd/cli/ad_others.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule if present.
|
||||
func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false }
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
74
cmd/cli/ad_windows.go
Normal file
74
cmd/cli/ad_windows.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
hh "github.com/microsoft/wmi/pkg/hardware/host"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/system"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule for domain if it's part of active directory.
|
||||
func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
|
||||
domain, err := system.GetActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Msgf("unable to get active directory domain: %v", err)
|
||||
return false
|
||||
}
|
||||
if domain == "" {
|
||||
mainLog.Load().Debug().Msg("no active directory domain found")
|
||||
return false
|
||||
}
|
||||
// Network rules are lowercase during toml config marshaling,
|
||||
// lowercase the domain here too for consistency.
|
||||
domain = strings.ToLower(domain)
|
||||
domainRuleAdded := addSplitDnsRule(cfg, domain)
|
||||
wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, "."))
|
||||
return domainRuleAdded || wildcardDomainRuleRuleAdded
|
||||
}
|
||||
|
||||
// addSplitDnsRule adds split-rule for given domain if there's no existed rule.
|
||||
// The return value indicates whether the split-rule was added or not.
|
||||
func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
|
||||
for n, lc := range cfg.Listener {
|
||||
if lc.Policy == nil {
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{}
|
||||
}
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
if _, ok := rule[domain]; ok {
|
||||
mainLog.Load().Debug().Msgf("split-rule %q already existed for listener.%s", domain, n)
|
||||
return false
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("adding split-rule %q for listener.%s", domain, n)
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
cs, err := hh.GetComputerSystem(whost)
|
||||
if cs != nil {
|
||||
defer cs.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pod, err := cs.GetPropertyPartOfDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if pod {
|
||||
return cs.GetPropertyDomain()
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
73
cmd/cli/ad_windows_test.go
Normal file
73
cmd/cli/ad_windows_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/system"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
)
|
||||
|
||||
func Test_getActiveDirectoryDomain(t *testing.T) {
|
||||
start := time.Now()
|
||||
domain, err := system.GetActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
domainPowershell, err := getActiveDirectoryDomainPowershell()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if domain != domainPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func getActiveDirectoryDomainPowershell() (string, error) {
|
||||
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
|
||||
output, err := powershell(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
func Test_addSplitDnsRule(t *testing.T) {
|
||||
newCfg := func(domains ...string) *ctrld.Config {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
lc := cfg.Listener["0"]
|
||||
for _, domain := range domains {
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *ctrld.Config
|
||||
domain string
|
||||
added bool
|
||||
}{
|
||||
{"added", newCfg(), "example.com", true},
|
||||
{"TLD existed", newCfg("example.com"), "*.example.com", true},
|
||||
{"wildcard existed", newCfg("*.example.com"), "example.com", true},
|
||||
{"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false},
|
||||
{"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
added := addSplitDnsRule(tc.cfg, tc.domain)
|
||||
assert.Equal(t, tc.added, added)
|
||||
})
|
||||
}
|
||||
}
|
||||
5
cmd/cli/cgo.go
Normal file
5
cmd/cli/cgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = true
|
||||
2052
cmd/cli/cli.go
Normal file
2052
cmd/cli/cli.go
Normal file
File diff suppressed because it is too large
Load Diff
46
cmd/cli/cli_test.go
Normal file
46
cmd/cli/cli_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_writeConfigFile(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
// simulate --config CLI flag by setting configPath manually.
|
||||
configPath = filepath.Join(tmpdir, "ctrld.toml")
|
||||
_, err := os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
assert.NoError(t, writeConfigFile(&cfg))
|
||||
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_isStableVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ver string
|
||||
isStable bool
|
||||
}{
|
||||
{"stable", "v1.3.5", true},
|
||||
{"pre", "v1.3.5-next", false},
|
||||
{"pre with commit hash", "v1.3.5-next-asdf", false},
|
||||
{"dev", "dev", false},
|
||||
{"empty", "dev", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isStableVersion(tc.ver); got != tc.isStable {
|
||||
t.Errorf("unexpected result for %s, want: %v, got: %v", tc.ver, tc.isStable, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1606
cmd/cli/commands.go
Normal file
1606
cmd/cli/commands.go
Normal file
File diff suppressed because it is too large
Load Diff
51
cmd/cli/conn.go
Normal file
51
cmd/cli/conn.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// logConn wraps a net.Conn, override the Write behavior.
|
||||
// runCmd uses this wrapper, so as long as startCmd finished,
|
||||
// ctrld log won't be flushed with un-necessary write errors.
|
||||
type logConn struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (lc *logConn) Read(b []byte) (n int, err error) {
|
||||
return lc.conn.Read(b)
|
||||
}
|
||||
|
||||
func (lc *logConn) Close() error {
|
||||
return lc.conn.Close()
|
||||
}
|
||||
|
||||
func (lc *logConn) LocalAddr() net.Addr {
|
||||
return lc.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (lc *logConn) RemoteAddr() net.Addr {
|
||||
return lc.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (lc *logConn) SetDeadline(t time.Time) error {
|
||||
return lc.conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (lc *logConn) SetReadDeadline(t time.Time) error {
|
||||
return lc.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (lc *logConn) SetWriteDeadline(t time.Time) error {
|
||||
return lc.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (lc *logConn) Write(b []byte) (int, error) {
|
||||
// Write performs writes with underlying net.Conn, ignore any errors happen.
|
||||
// "ctrld run" command use this wrapper to report errors to "ctrld start".
|
||||
// If no error occurred, "ctrld start" may finish before "ctrld run" attempt
|
||||
// to close the connection, so ignore errors conservatively here, prevent
|
||||
// un-necessary error "write to closed connection" flushed to ctrld log.
|
||||
_, _ = lc.conn.Write(b)
|
||||
return len(b), nil
|
||||
}
|
||||
44
cmd/cli/control_client.go
Normal file
44
cmd/cli/control_client.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type controlClient struct {
|
||||
c *http.Client
|
||||
}
|
||||
|
||||
func newControlClient(addr string) *controlClient {
|
||||
return &controlClient{c: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
d := net.Dialer{}
|
||||
return d.DialContext(ctx, "unix", addr)
|
||||
},
|
||||
},
|
||||
Timeout: time.Second * 30,
|
||||
}}
|
||||
}
|
||||
|
||||
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
|
||||
// for log/send, set the timeout to 5 minutes
|
||||
if path == sendLogsPath {
|
||||
c.c.Timeout = time.Minute * 5
|
||||
}
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
// postStream sends a POST request with no timeout, suitable for long-lived streaming connections.
|
||||
func (c *controlClient) postStream(path string, data io.Reader) (*http.Response, error) {
|
||||
c.c.Timeout = 0
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
// deactivationRequest represents request for validating deactivation pin.
|
||||
type deactivationRequest struct {
|
||||
Pin int64 `json:"pin"`
|
||||
}
|
||||
520
cmd/cli/control_server.go
Normal file
520
cmd/cli/control_server.go
Normal file
@@ -0,0 +1,520 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
)
|
||||
|
||||
const (
|
||||
contentTypeJson = "application/json"
|
||||
listClientsPath = "/clients"
|
||||
startedPath = "/started"
|
||||
reloadPath = "/reload"
|
||||
deactivationPath = "/deactivation"
|
||||
cdPath = "/cd"
|
||||
ifacePath = "/iface"
|
||||
viewLogsPath = "/log/view"
|
||||
sendLogsPath = "/log/send"
|
||||
tailLogsPath = "/log/tail"
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
Name string `json:"name"`
|
||||
All bool `json:"all"`
|
||||
OK bool `json:"ok"`
|
||||
InterceptMode string `json:"intercept_mode,omitempty"` // "dns", "hard", or "" (not intercepting)
|
||||
}
|
||||
|
||||
type controlServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
addr string
|
||||
}
|
||||
|
||||
func newControlServer(addr string) (*controlServer, error) {
|
||||
mux := http.NewServeMux()
|
||||
s := &controlServer{
|
||||
server: &http.Server{Handler: mux},
|
||||
mux: mux,
|
||||
}
|
||||
s.addr = addr
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *controlServer) start() error {
|
||||
_ = os.Remove(s.addr)
|
||||
unixListener, err := net.Listen("unix", s.addr)
|
||||
if l, ok := unixListener.(*net.UnixListener); ok {
|
||||
l.SetUnlinkOnClose(true)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.server.Serve(unixListener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *controlServer) stop() error {
|
||||
_ = os.Remove(s.addr)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
|
||||
defer cancel()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *controlServer) register(pattern string, handler http.Handler) {
|
||||
s.mux.Handle(pattern, jsonResponse(handler))
|
||||
}
|
||||
|
||||
func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
mainLog.Load().Debug().Msg("handling list clients request")
|
||||
|
||||
clients := p.ciTable.ListClients()
|
||||
mainLog.Load().Debug().Int("client_count", len(clients)).Msg("retrieved clients list")
|
||||
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].IP.Less(clients[j].IP)
|
||||
})
|
||||
mainLog.Load().Debug().Msg("sorted clients by IP address")
|
||||
|
||||
if p.metricsQueryStats.Load() {
|
||||
mainLog.Load().Debug().Msg("metrics query stats enabled, collecting query counts")
|
||||
|
||||
for idx, client := range clients {
|
||||
mainLog.Load().Debug().
|
||||
Int("index", idx).
|
||||
Str("ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("processing client metrics")
|
||||
|
||||
client.IncludeQueryCount = true
|
||||
dm := &dto.Metric{}
|
||||
|
||||
if statsClientQueriesCount.MetricVec == nil {
|
||||
mainLog.Load().Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("skipping metrics collection: MetricVec is nil")
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
|
||||
client.IP.String(),
|
||||
client.Mac,
|
||||
client.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("failed to get metrics for client")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.Write(dm); err == nil && dm.Counter != nil {
|
||||
client.QueryCount = int64(dm.Counter.GetValue())
|
||||
mainLog.Load().Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Int64("query_count", client.QueryCount).
|
||||
Msg("successfully collected query count")
|
||||
} else if err != nil {
|
||||
mainLog.Load().Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("failed to write metric")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("metrics query stats disabled, skipping query counts")
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&clients); err != nil {
|
||||
mainLog.Load().Error().
|
||||
Err(err).
|
||||
Int("client_count", len(clients)).
|
||||
Msg("failed to encode clients response")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Int("client_count", len(clients)).
|
||||
Msg("successfully sent clients list response")
|
||||
}))
|
||||
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
select {
|
||||
case <-p.onStartedDone:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case <-time.After(10 * time.Second):
|
||||
w.WriteHeader(http.StatusRequestTimeout)
|
||||
}
|
||||
}))
|
||||
p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
listeners := make(map[string]*ctrld.ListenerConfig)
|
||||
p.mu.Lock()
|
||||
for k, v := range p.cfg.Listener {
|
||||
listeners[k] = &ctrld.ListenerConfig{
|
||||
IP: v.IP,
|
||||
Port: v.Port,
|
||||
}
|
||||
}
|
||||
oldSvc := p.cfg.Service
|
||||
p.mu.Unlock()
|
||||
if err := p.sendReloadSignal(); err != nil {
|
||||
mainLog.Load().Err(err).Msg("could not send reload signal")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.reloadDoneCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Checking for cases that we could not do a reload.
|
||||
|
||||
// 1. Listener config ip or port changes.
|
||||
for k, v := range p.cfg.Listener {
|
||||
l := listeners[k]
|
||||
if l == nil || l.IP != v.IP || l.Port != v.Port {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Service config changes.
|
||||
if !reflect.DeepEqual(oldSvc, p.cfg.Service) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, reload is done.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
p.cs.register(deactivationPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
// Non-cd mode always allowing deactivation.
|
||||
if cdUID == "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch pin code from API.
|
||||
rcReq := &controld.ResolverConfigRequest{
|
||||
RawUID: cdUID,
|
||||
Version: rootCmd.Version,
|
||||
Metadata: ctrld.SystemMetadataRuntime(context.Background()),
|
||||
}
|
||||
if rc, err := controld.FetchResolverConfig(rcReq, cdDev); rc != nil {
|
||||
if rc.DeactivationPin != nil {
|
||||
cdDeactivationPin.Store(*rc.DeactivationPin)
|
||||
} else {
|
||||
cdDeactivationPin.Store(defaultDeactivationPin)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not re-fetch deactivation pin code")
|
||||
}
|
||||
|
||||
// If pin code not set, allowing deactivation.
|
||||
if !deactivationPinSet() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
var req deactivationRequest
|
||||
if err := json.NewDecoder(request.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusPreconditionFailed)
|
||||
mainLog.Load().Err(err).Msg("invalid deactivation request")
|
||||
return
|
||||
}
|
||||
|
||||
code := http.StatusForbidden
|
||||
switch req.Pin {
|
||||
case cdDeactivationPin.Load():
|
||||
code = http.StatusOK
|
||||
select {
|
||||
case p.pinCodeValidCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
case defaultDeactivationPin:
|
||||
// If the pin code was set, but users do not provide --pin, return proper code to client.
|
||||
code = http.StatusBadRequest
|
||||
}
|
||||
w.WriteHeader(code)
|
||||
}))
|
||||
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if cdUID != "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(cdUID))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
res := &ifaceResponse{Name: iface}
|
||||
// p.setDNS is only called when running as a service
|
||||
if !service.Interactive() {
|
||||
<-p.csSetDnsDone
|
||||
if p.csSetDnsOk {
|
||||
res.Name = p.runningIface
|
||||
res.All = p.requiredMultiNICsConfig
|
||||
res.OK = true
|
||||
// Report intercept mode to the start command for proper log output.
|
||||
if interceptMode == "dns" || interceptMode == "hard" {
|
||||
res.InterceptMode = interceptMode
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(res); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
lr, err := p.logReader()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer lr.r.Close()
|
||||
if lr.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(lr.r)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if time.Since(p.internalLogSent) < logWriterSentInterval {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r, err := p.logReader()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
req := &controld.LogsRequest{
|
||||
UID: cdUID,
|
||||
Data: r.r,
|
||||
}
|
||||
mainLog.Load().Debug().Msg("sending log file to ControlD server")
|
||||
resp := logSentResponse{Size: r.size}
|
||||
if err := controld.SendLogs(req, cdDev); err != nil {
|
||||
mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err)
|
||||
resp.Error = err.Error()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("sending log file successfully")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
p.internalLogSent = time.Now()
|
||||
}))
|
||||
p.cs.register(tailLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming unsupported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Determine logging mode and validate before starting the stream.
|
||||
var lw *logWriter
|
||||
useInternalLog := p.needInternalLogging()
|
||||
if useInternalLog {
|
||||
p.mu.Lock()
|
||||
lw = p.internalLogWriter
|
||||
p.mu.Unlock()
|
||||
if lw == nil {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
} else if p.cfg.Service.LogPath == "" {
|
||||
// No logging configured at all.
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse optional "lines" query param for initial context.
|
||||
numLines := 10
|
||||
if v := request.URL.Query().Get("lines"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n >= 0 {
|
||||
numLines = n
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("Transfer-Encoding", "chunked")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if useInternalLog {
|
||||
// Internal logging mode: subscribe to the logWriter.
|
||||
|
||||
// Send last N lines as initial context.
|
||||
if numLines > 0 {
|
||||
if tail := lw.tailLastLines(numLines); len(tail) > 0 {
|
||||
w.Write(tail)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
ch, unsub := lw.Subscribe()
|
||||
defer unsub()
|
||||
for {
|
||||
select {
|
||||
case data, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(data); err != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
case <-request.Context().Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// File-based logging mode: tail the log file.
|
||||
logFile := normalizeLogFilePath(p.cfg.Service.LogPath)
|
||||
f, err := os.Open(logFile)
|
||||
if err != nil {
|
||||
// Already committed 200, just return.
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Seek to show last N lines.
|
||||
if numLines > 0 {
|
||||
if tail := tailFileLastLines(f, numLines); len(tail) > 0 {
|
||||
w.Write(tail)
|
||||
flusher.Flush()
|
||||
}
|
||||
} else {
|
||||
// Seek to end.
|
||||
f.Seek(0, io.SeekEnd)
|
||||
}
|
||||
|
||||
// Poll for new data.
|
||||
buf := make([]byte, 4096)
|
||||
ticker := time.NewTicker(200 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
n, err := f.Read(buf)
|
||||
if n > 0 {
|
||||
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil && err != io.EOF {
|
||||
return
|
||||
}
|
||||
case <-request.Context().Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
// tailFileLastLines reads the last n lines from a file and returns them.
|
||||
// The file position is left at the end of the file after this call.
|
||||
func tailFileLastLines(f *os.File, n int) []byte {
|
||||
stat, err := f.Stat()
|
||||
if err != nil || stat.Size() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read from the end in chunks to find the last n lines.
|
||||
const chunkSize = 4096
|
||||
fileSize := stat.Size()
|
||||
var lines []byte
|
||||
offset := fileSize
|
||||
count := 0
|
||||
|
||||
for offset > 0 && count <= n {
|
||||
readSize := int64(chunkSize)
|
||||
if readSize > offset {
|
||||
readSize = offset
|
||||
}
|
||||
offset -= readSize
|
||||
buf := make([]byte, readSize)
|
||||
nRead, err := f.ReadAt(buf, offset)
|
||||
if err != nil && err != io.EOF {
|
||||
break
|
||||
}
|
||||
buf = buf[:nRead]
|
||||
lines = append(buf, lines...)
|
||||
|
||||
// Count newlines in this chunk.
|
||||
for _, b := range buf {
|
||||
if b == '\n' {
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Trim to last n lines.
|
||||
idx := 0
|
||||
nlCount := 0
|
||||
for i := len(lines) - 1; i >= 0; i-- {
|
||||
if lines[i] == '\n' {
|
||||
nlCount++
|
||||
if nlCount == n+1 {
|
||||
idx = i + 1
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
lines = lines[idx:]
|
||||
|
||||
// Seek to end of file for subsequent reads.
|
||||
f.Seek(0, io.SeekEnd)
|
||||
return lines
|
||||
}
|
||||
|
||||
func jsonResponse(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
54
cmd/cli/control_server_test.go
Normal file
54
cmd/cli/control_server_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestControlServer(t *testing.T) {
|
||||
f, err := os.CreateTemp("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
f.Close()
|
||||
|
||||
s, err := newControlServer(f.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pattern := "/ping"
|
||||
respBody := []byte("pong")
|
||||
s.register(pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write(respBody)
|
||||
}))
|
||||
if err := s.start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := newControlClient(f.Name())
|
||||
resp, err := c.post(pattern, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("unepxected response code: %d", resp.StatusCode)
|
||||
}
|
||||
if ct := resp.Header.Get("content-type"); ct != contentTypeJson {
|
||||
t.Fatalf("unexpected content type: %s", ct)
|
||||
}
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(buf, respBody) {
|
||||
t.Errorf("unexpected response body, want: %q, got: %q", string(respBody), string(buf))
|
||||
}
|
||||
if err := s.stop(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package cli
|
||||
|
||||
//lint:ignore U1000 use in os_linux.go
|
||||
type getDNS func(iface string) []string
|
||||
1818
cmd/cli/dns_intercept_darwin.go
Normal file
1818
cmd/cli/dns_intercept_darwin.go
Normal file
File diff suppressed because it is too large
Load Diff
143
cmd/cli/dns_intercept_darwin_test.go
Normal file
143
cmd/cli/dns_intercept_darwin_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
//go:build darwin
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// buildPFAnchorRules tests
|
||||
// =============================================================================
|
||||
|
||||
func TestPFBuildAnchorRules_Basic(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{Listener: map[string]*ctrld.ListenerConfig{"0": {IP: "127.0.0.1", Port: 53}}}}
|
||||
rules := p.buildPFAnchorRules(nil)
|
||||
|
||||
// rdr (translation) must come before pass (filtering)
|
||||
rdrIdx := strings.Index(rules, "rdr on lo0 inet proto udp")
|
||||
passRouteIdx := strings.Index(rules, "pass out quick on ! lo0 route-to lo0 inet proto udp")
|
||||
passInIdx := strings.Index(rules, "pass in quick on lo0 reply-to lo0")
|
||||
|
||||
if rdrIdx < 0 {
|
||||
t.Fatal("missing rdr rule")
|
||||
}
|
||||
if passRouteIdx < 0 {
|
||||
t.Fatal("missing pass out route-to rule")
|
||||
}
|
||||
if passInIdx < 0 {
|
||||
t.Fatal("missing pass in on lo0 rule")
|
||||
}
|
||||
if rdrIdx >= passRouteIdx {
|
||||
t.Error("rdr rules must come before pass out route-to rules")
|
||||
}
|
||||
if passRouteIdx >= passInIdx {
|
||||
t.Error("pass out route-to must come before pass in on lo0")
|
||||
}
|
||||
|
||||
// Both UDP and TCP rdr rules
|
||||
if !strings.Contains(rules, "proto udp") || !strings.Contains(rules, "proto tcp") {
|
||||
t.Error("must have both UDP and TCP rdr rules")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPFBuildAnchorRules_WithVPNServers(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{Listener: map[string]*ctrld.ListenerConfig{"0": {IP: "127.0.0.1", Port: 53}}}}
|
||||
vpnServers := []vpnDNSExemption{
|
||||
{Server: "10.8.0.1"},
|
||||
{Server: "10.8.0.2"},
|
||||
}
|
||||
rules := p.buildPFAnchorRules(vpnServers)
|
||||
|
||||
// VPN exemption rules must appear
|
||||
for _, s := range vpnServers {
|
||||
if !strings.Contains(rules, s.Server) {
|
||||
t.Errorf("missing VPN exemption for %s", s.Server)
|
||||
}
|
||||
}
|
||||
|
||||
// VPN exemptions must come before route-to
|
||||
exemptIdx := strings.Index(rules, "10.8.0.1 port 53 group")
|
||||
routeIdx := strings.Index(rules, "pass out quick on ! lo0 route-to lo0 inet proto udp")
|
||||
if exemptIdx < 0 {
|
||||
t.Fatal("missing VPN exemption rule for 10.8.0.1")
|
||||
}
|
||||
if routeIdx < 0 {
|
||||
t.Fatal("missing route-to rule")
|
||||
}
|
||||
if exemptIdx >= routeIdx {
|
||||
t.Error("VPN exemptions must come before route-to rules")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPFBuildAnchorRules_IPv4AndIPv6VPN(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{Listener: map[string]*ctrld.ListenerConfig{"0": {IP: "127.0.0.1", Port: 53}}}}
|
||||
vpnServers := []vpnDNSExemption{
|
||||
{Server: "10.8.0.1"},
|
||||
{Server: "fd00::1"},
|
||||
}
|
||||
rules := p.buildPFAnchorRules(vpnServers)
|
||||
|
||||
// IPv4 server should use "inet"
|
||||
lines := strings.Split(rules, "\n")
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, "10.8.0.1") && strings.HasPrefix(line, "pass") {
|
||||
if !strings.Contains(line, "inet ") {
|
||||
t.Error("IPv4 VPN server rule should contain 'inet'")
|
||||
}
|
||||
if strings.Contains(line, "inet6") {
|
||||
t.Error("IPv4 VPN server rule should not contain 'inet6'")
|
||||
}
|
||||
}
|
||||
if strings.Contains(line, "fd00::1") && strings.HasPrefix(line, "pass") {
|
||||
if !strings.Contains(line, "inet6") {
|
||||
t.Error("IPv6 VPN server rule should contain 'inet6'")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPFBuildAnchorRules_Ordering(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{Listener: map[string]*ctrld.ListenerConfig{"0": {IP: "127.0.0.1", Port: 53}}}}
|
||||
vpnServers := []vpnDNSExemption{
|
||||
{Server: "10.8.0.1"},
|
||||
}
|
||||
rules := p.buildPFAnchorRules(vpnServers)
|
||||
|
||||
// Verify ordering: rdr → exemptions → route-to → pass in on lo0
|
||||
rdrIdx := strings.Index(rules, "rdr on lo0 inet proto udp")
|
||||
exemptIdx := strings.Index(rules, "pass out quick on ! lo0 inet proto { udp, tcp } from any to 10.8.0.1 port 53 group _ctrld")
|
||||
routeIdx := strings.Index(rules, "pass out quick on ! lo0 route-to lo0 inet proto udp")
|
||||
passInIdx := strings.Index(rules, "pass in quick on lo0 reply-to lo0")
|
||||
|
||||
if rdrIdx < 0 || exemptIdx < 0 || routeIdx < 0 || passInIdx < 0 {
|
||||
t.Fatalf("missing expected rules: rdr=%d exempt=%d route=%d passIn=%d", rdrIdx, exemptIdx, routeIdx, passInIdx)
|
||||
}
|
||||
|
||||
if !(rdrIdx < exemptIdx && exemptIdx < routeIdx && routeIdx < passInIdx) {
|
||||
t.Errorf("incorrect rule ordering: rdr(%d) < exempt(%d) < route(%d) < passIn(%d)", rdrIdx, exemptIdx, routeIdx, passInIdx)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPFAddressFamily tests the pfAddressFamily helper.
|
||||
func TestPFAddressFamily(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
want string
|
||||
}{
|
||||
{"10.0.0.1", "inet"},
|
||||
{"192.168.1.1", "inet"},
|
||||
{"127.0.0.1", "inet"},
|
||||
{"::1", "inet6"},
|
||||
{"fd00::1", "inet6"},
|
||||
{"2001:db8::1", "inet6"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := pfAddressFamily(tt.ip); got != tt.want {
|
||||
t.Errorf("pfAddressFamily(%q) = %q, want %q", tt.ip, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
39
cmd/cli/dns_intercept_others.go
Normal file
39
cmd/cli/dns_intercept_others.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// startDNSIntercept is not supported on this platform.
|
||||
// DNS intercept mode is only available on Windows (via WFP) and macOS (via pf).
|
||||
func (p *prog) startDNSIntercept() error {
|
||||
return fmt.Errorf("dns intercept: not supported on this platform (only Windows and macOS)")
|
||||
}
|
||||
|
||||
// stopDNSIntercept is a no-op on unsupported platforms.
|
||||
func (p *prog) stopDNSIntercept() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// exemptVPNDNSServers is a no-op on unsupported platforms.
|
||||
func (p *prog) exemptVPNDNSServers(exemptions []vpnDNSExemption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensurePFAnchorActive is a no-op on unsupported platforms.
|
||||
func (p *prog) ensurePFAnchorActive() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// checkTunnelInterfaceChanges is a no-op on unsupported platforms.
|
||||
func (p *prog) checkTunnelInterfaceChanges() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// scheduleDelayedRechecks is a no-op on unsupported platforms.
|
||||
func (p *prog) scheduleDelayedRechecks() {}
|
||||
|
||||
// pfInterceptMonitor is a no-op on unsupported platforms.
|
||||
func (p *prog) pfInterceptMonitor() {}
|
||||
1639
cmd/cli/dns_intercept_windows.go
Normal file
1639
cmd/cli/dns_intercept_windows.go
Normal file
File diff suppressed because it is too large
Load Diff
1953
cmd/cli/dns_proxy.go
Normal file
1953
cmd/cli/dns_proxy.go
Normal file
File diff suppressed because it is too large
Load Diff
476
cmd/cli/dns_proxy_test.go
Normal file
476
cmd/cli/dns_proxy_test.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
)
|
||||
|
||||
func Test_wildcardMatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wildcard string
|
||||
domain string
|
||||
match bool
|
||||
}{
|
||||
{"domain - prefix parent should not match", "*.example.com", "example.com", false},
|
||||
{"domain - prefix", "*.example.com", "anything.example.com", true},
|
||||
{"domain - prefix not match other s", "*.example.com", "other.org", false},
|
||||
{"domain - prefix not match s in name", "*.example.com", "eexample.com", false},
|
||||
{"domain - suffix", "suffix.*", "suffix.example.com", true},
|
||||
{"domain - suffix not match other", "suffix.*", "suffix1.example.com", false},
|
||||
{"domain - both", "suffix.*.example.com", "suffix.anything.example.com", true},
|
||||
{"domain - both not match", "suffix.*.example.com", "suffix1.suffix.example.com", false},
|
||||
{"domain - case-insensitive", "*.EXAMPLE.com", "anything.example.com", true},
|
||||
{"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false},
|
||||
{"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false},
|
||||
{"mac - suffix", "d4:67:98:*", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - suffix not match other", "d4:67:98:*", "d4:67:97:15:b4:2b", false},
|
||||
{"mac - both", "d4:67:98:*:b4:2b", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - both not match", "d4:67:98:*:b4:2b", "d4:67:97:05:c4:2b", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := wildcardMatches(tc.wildcard, tc.domain); got != tc.match {
|
||||
t.Errorf("unexpected result, wildcard: %s, domain: %s, want: %v, got: %v", tc.wildcard, tc.domain, tc.match, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_canonicalName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
canonical string
|
||||
}{
|
||||
{"fqdn to canonical", "example.com.", "example.com"},
|
||||
{"already canonical", "example.com", "example.com"},
|
||||
{"case insensitive", "Example.Com.", "example.com"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := canonicalName(tc.domain); got != tc.canonical {
|
||||
t.Errorf("unexpected result, want: %s, got: %s", tc.canonical, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
||||
p := &prog{cfg: cfg}
|
||||
p.um = newUpstreamMonitor(p.cfg)
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
for _, nc := range p.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
mac string
|
||||
defaultUpstreamNum string
|
||||
lc *ctrld.ListenerConfig
|
||||
domain string
|
||||
upstreams []string
|
||||
matched bool
|
||||
testLogMsg string
|
||||
}{
|
||||
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
|
||||
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, network := range []string{"udp", "tcp"} {
|
||||
var (
|
||||
addr net.Addr
|
||||
err error
|
||||
)
|
||||
switch network {
|
||||
case "udp":
|
||||
addr, err = net.ResolveUDPAddr(network, tc.ip)
|
||||
case "tcp":
|
||||
addr, err = net.ResolveTCPAddr(network, tc.ip)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
|
||||
p.proxy(ctx, &proxyRequest{
|
||||
msg: newDnsMsgWithHostname("foo", dns.TypeA),
|
||||
ufr: ufr,
|
||||
})
|
||||
assert.Equal(t, tc.matched, ufr.matched)
|
||||
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
||||
if tc.testLogMsg != "" {
|
||||
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
cacher, err := dnscache.NewLRUCache(4096)
|
||||
require.NoError(t, err)
|
||||
prog.cache = cacher
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com", dns.TypeA)
|
||||
msg.MsgHdr.RecursionDesired = true
|
||||
answer1 := new(dns.Msg)
|
||||
answer1.SetRcode(msg, dns.RcodeSuccess)
|
||||
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute)))
|
||||
answer2 := new(dns.Msg)
|
||||
answer2.SetRcode(msg, dns.RcodeRefused)
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
||||
|
||||
req1 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.1"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
req2 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.0"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
got1 := prog.proxy(context.Background(), req1)
|
||||
got2 := prog.proxy(context.Background(), req2)
|
||||
assert.NotSame(t, got1, got2)
|
||||
assert.Equal(t, answer1.Rcode, got1.answer.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.answer.Rcode)
|
||||
}
|
||||
|
||||
func Test_ipAndMacFromMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
wantIp bool
|
||||
mac string
|
||||
wantMac bool
|
||||
}{
|
||||
{"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
{"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatal("missing IP")
|
||||
}
|
||||
hw, err := net.ParseMAC(tc.mac)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
if tc.wantMac {
|
||||
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
||||
o.Option = append(o.Option, ec1)
|
||||
}
|
||||
if tc.wantIp {
|
||||
ec2 := &dns.EDNS0_SUBNET{Address: ip}
|
||||
o.Option = append(o.Option, ec2)
|
||||
}
|
||||
m.Extra = append(m.Extra, o)
|
||||
gotIP, gotMac := ipAndMacFromMsg(m)
|
||||
if tc.wantMac && gotMac != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac)
|
||||
}
|
||||
if !tc.wantMac && gotMac != "" {
|
||||
t.Errorf("unexpected mac: %q", gotMac)
|
||||
}
|
||||
if tc.wantIp && gotIP != tc.ip {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP)
|
||||
}
|
||||
if !tc.wantIp && gotIP != "" {
|
||||
t.Errorf("unexpected ip: %q", gotIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_remoteAddrFromMsg(t *testing.T) {
|
||||
loopbackIP := net.ParseIP("127.0.0.1")
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
ci *ctrld.ClientInfo
|
||||
want string
|
||||
}{
|
||||
{"tcp", &net.TCPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.10"}, "192.168.1.10:12345"},
|
||||
{"udp", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.11"}, "192.168.1.11:12345"},
|
||||
{"nil client info", &net.UDPAddr{IP: loopbackIP, Port: 12345}, nil, "127.0.0.1:12345"},
|
||||
{"empty ip", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{}, "127.0.0.1:12345"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
addr := spoofRemoteAddr(tc.addr, tc.ci)
|
||||
if addr.String() != tc.want {
|
||||
t.Errorf("unexpected result, want: %q, got: %q", tc.want, addr.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ipFromARPA(t *testing.T) {
|
||||
tests := []struct {
|
||||
IP string
|
||||
ARPA string
|
||||
}{
|
||||
{"1.2.3.4", "4.3.2.1.in-addr.arpa."},
|
||||
{"245.110.36.114", "114.36.110.245.in-addr.arpa."},
|
||||
{"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."},
|
||||
{"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."},
|
||||
{"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."},
|
||||
{"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"", "asd.in-addr.arpa."},
|
||||
{"", "asd.ip6.arpa."},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.IP, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) {
|
||||
t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithClientIP(ip string) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)})
|
||||
m.Extra = append(m.Extra, o)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_stripClientSubnet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
wantSubnet bool
|
||||
}{
|
||||
{"no edns0", new(dns.Msg), false},
|
||||
{"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false},
|
||||
{"loopback IP v6", newDnsMsgWithClientIP("::1"), false},
|
||||
{"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false},
|
||||
{"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false},
|
||||
{"public IP", newDnsMsgWithClientIP("1.1.1.1"), true},
|
||||
{"invalid IP", newDnsMsgWithClientIP(""), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
stripClientSubnet(tc.msg)
|
||||
hasSubnet := false
|
||||
if opt := tc.msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
if _, ok := s.(*dns.EDNS0_SUBNET); ok {
|
||||
hasSubnet = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if tc.wantSubnet != hasSubnet {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(hostname, typ)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isLanHostnameQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isLanHostnameQuery bool
|
||||
}{
|
||||
{"A", newDnsMsgWithHostname("foo", dns.TypeA), true},
|
||||
{"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true},
|
||||
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
||||
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
||||
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
||||
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
||||
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
||||
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg {
|
||||
t.Helper()
|
||||
m := new(dns.Msg)
|
||||
ptr, err := dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.SetQuestion(ptr, dns.TypePTR)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isPrivatePtrLookup bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true},
|
||||
{"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true},
|
||||
{"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true},
|
||||
{"CGNAT", newDnsMsgPtr("100.66.27.28", t), true},
|
||||
{"Loopback", newDnsMsgPtr("127.0.0.1", t), true},
|
||||
{"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true},
|
||||
{"Public IP", newDnsMsgPtr("8.8.8.8", t), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSrvLanLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isSrvLookup bool
|
||||
}{
|
||||
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
||||
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
||||
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isWanClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
isWanClient bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false},
|
||||
{"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false},
|
||||
{"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false},
|
||||
{"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false},
|
||||
{"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false},
|
||||
{"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false},
|
||||
{"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isWanClient(tc.addr); tc.isWanClient != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_queryFromSelf(t *testing.T) {
|
||||
p := &prog{}
|
||||
require.NotPanics(t, func() {
|
||||
p.queryFromSelf("")
|
||||
})
|
||||
require.NotPanics(t, func() {
|
||||
p.queryFromSelf("foo")
|
||||
})
|
||||
}
|
||||
14
cmd/cli/hostname.go
Normal file
14
cmd/cli/hostname.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cli
|
||||
|
||||
import "regexp"
|
||||
|
||||
// validHostname reports whether hostname is a valid hostname.
|
||||
// A valid hostname contains 3 -> 64 characters and conform to RFC1123.
|
||||
func validHostname(hostname string) bool {
|
||||
hostnameLen := len(hostname)
|
||||
if hostnameLen < 3 || hostnameLen > 64 {
|
||||
return false
|
||||
}
|
||||
validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
|
||||
return validHostnameRfc1123.MatchString(hostname)
|
||||
}
|
||||
35
cmd/cli/hostname_test.go
Normal file
35
cmd/cli/hostname_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_validHostname(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
valid bool
|
||||
}{
|
||||
{"localhost", "localhost", true},
|
||||
{"localdomain", "localhost.localdomain", true},
|
||||
{"localhost6", "localhost6.localdomain6", true},
|
||||
{"ip6", "ip6-localhost", true},
|
||||
{"non-domain", "controld", true},
|
||||
{"domain", "controld.com", true},
|
||||
{"empty", "", false},
|
||||
{"min length", "fo", false},
|
||||
{"max length", strings.Repeat("a", 65), false},
|
||||
{"special char", "foo!", false},
|
||||
{"non-ascii", "fooΩ", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.hostname, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, validHostname(tc.hostname) == tc.valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
95
cmd/cli/library.go
Normal file
95
cmd/cli/library.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AppCallback provides hooks for injecting certain functionalities
|
||||
// from mobile platforms to main ctrld cli.
|
||||
type AppCallback struct {
|
||||
HostName func() string
|
||||
LanIp func() string
|
||||
MacAddress func() string
|
||||
Exit func(error string)
|
||||
}
|
||||
|
||||
// AppConfig allows overwriting ctrld cli flags from mobile platforms.
|
||||
type AppConfig struct {
|
||||
CdUID string
|
||||
ProvisionID string
|
||||
CustomHostname string
|
||||
HomeDir string
|
||||
UpstreamProto string
|
||||
Verbose int
|
||||
LogPath string
|
||||
}
|
||||
|
||||
const (
|
||||
defaultHTTPTimeout = 30 * time.Second
|
||||
defaultMaxRetries = 3
|
||||
downloadServerIp = "23.171.240.151"
|
||||
)
|
||||
|
||||
// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback
|
||||
func httpClientWithFallback(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
// Prefer IPv4 over IPv6
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
FallbackDelay: 1 * time.Millisecond, // Very small delay to prefer IPv4
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// doWithRetry performs an HTTP request with retries
|
||||
func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) {
|
||||
var lastErr error
|
||||
client := httpClientWithFallback(defaultHTTPTimeout)
|
||||
var ipReq *http.Request
|
||||
if ip != "" {
|
||||
ipReq = req.Clone(req.Context())
|
||||
ipReq.Host = ip
|
||||
ipReq.URL.Host = ip
|
||||
}
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
time.Sleep(time.Second * time.Duration(attempt+1)) // Exponential backoff
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if ipReq != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("dial to %q failed", req.Host)
|
||||
mainLog.Load().Warn().Msgf("fallback to direct IP to download prod version: %q", ip)
|
||||
resp, err = client.Do(ipReq)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
mainLog.Load().Debug().Err(err).
|
||||
Str("method", req.Method).
|
||||
Str("url", req.URL.String()).
|
||||
Msgf("HTTP request attempt %d/%d failed", attempt+1, maxRetries)
|
||||
}
|
||||
return nil, fmt.Errorf("failed after %d attempts to %s %s: %v", maxRetries, req.Method, req.URL, lastErr)
|
||||
}
|
||||
|
||||
// Helper for making GET requests with retries
|
||||
func getWithRetry(url string, ip string) (*http.Response, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doWithRetry(req, defaultMaxRetries, ip)
|
||||
}
|
||||
339
cmd/cli/log_tail_test.go
Normal file
339
cmd/cli/log_tail_test.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// logWriter.tailLastLines tests
|
||||
// =============================================================================
|
||||
|
||||
func Test_logWriter_tailLastLines_Empty(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
if got := lw.tailLastLines(10); got != nil {
|
||||
t.Fatalf("expected nil for empty buffer, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_ZeroLines(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\n"))
|
||||
if got := lw.tailLastLines(0); got != nil {
|
||||
t.Fatalf("expected nil for n=0, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_NegativeLines(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\n"))
|
||||
if got := lw.tailLastLines(-1); got != nil {
|
||||
t.Fatalf("expected nil for n=-1, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_FewerThanN(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\n"))
|
||||
got := string(lw.tailLastLines(10))
|
||||
want := "line1\nline2\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_ExactN(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\nline3\n"))
|
||||
got := string(lw.tailLastLines(3))
|
||||
want := "line1\nline2\nline3\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_MoreThanN(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\nline3\nline4\nline5\n"))
|
||||
got := string(lw.tailLastLines(2))
|
||||
want := "line4\nline5\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_NoTrailingNewline(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("line1\nline2\nline3"))
|
||||
// Without trailing newline, "line3" is a partial line.
|
||||
// Asking for 1 line returns the last newline-terminated line plus the partial.
|
||||
got := string(lw.tailLastLines(1))
|
||||
want := "line2\nline3"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_SingleLineNoNewline(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("only line"))
|
||||
got := string(lw.tailLastLines(5))
|
||||
want := "only line"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_tailLastLines_SingleLineWithNewline(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
lw.Write([]byte("only line\n"))
|
||||
got := string(lw.tailLastLines(1))
|
||||
want := "only line\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// logWriter.Subscribe tests
|
||||
// =============================================================================
|
||||
|
||||
func Test_logWriter_Subscribe_Basic(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
ch, unsub := lw.Subscribe()
|
||||
defer unsub()
|
||||
|
||||
msg := []byte("hello world\n")
|
||||
lw.Write(msg)
|
||||
|
||||
select {
|
||||
case got := <-ch:
|
||||
if string(got) != string(msg) {
|
||||
t.Fatalf("got %q, want %q", got, msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for subscriber data")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_MultipleSubscribers(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
ch1, unsub1 := lw.Subscribe()
|
||||
defer unsub1()
|
||||
ch2, unsub2 := lw.Subscribe()
|
||||
defer unsub2()
|
||||
|
||||
msg := []byte("broadcast\n")
|
||||
lw.Write(msg)
|
||||
|
||||
for i, ch := range []<-chan []byte{ch1, ch2} {
|
||||
select {
|
||||
case got := <-ch:
|
||||
if string(got) != string(msg) {
|
||||
t.Fatalf("subscriber %d: got %q, want %q", i, got, msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("subscriber %d: timed out", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_Unsubscribe(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
ch, unsub := lw.Subscribe()
|
||||
|
||||
// Verify subscribed.
|
||||
lw.Write([]byte("before unsub\n"))
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out before unsub")
|
||||
}
|
||||
|
||||
unsub()
|
||||
|
||||
// Channel should be closed after unsub.
|
||||
if _, ok := <-ch; ok {
|
||||
t.Fatal("channel should be closed after unsubscribe")
|
||||
}
|
||||
|
||||
// Verify subscriber list is empty.
|
||||
lw.mu.Lock()
|
||||
count := len(lw.subscribers)
|
||||
lw.mu.Unlock()
|
||||
if count != 0 {
|
||||
t.Fatalf("expected 0 subscribers after unsub, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_UnsubscribeIdempotent(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
_, unsub := lw.Subscribe()
|
||||
unsub()
|
||||
// Second unsub should not panic.
|
||||
unsub()
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_SlowSubscriberDropped(t *testing.T) {
|
||||
lw := newLogWriterWithSize(4096)
|
||||
ch, unsub := lw.Subscribe()
|
||||
defer unsub()
|
||||
|
||||
// Fill the subscriber channel (buffer size is 256).
|
||||
for i := 0; i < 300; i++ {
|
||||
lw.Write([]byte("msg\n"))
|
||||
}
|
||||
|
||||
// Should have 256 buffered messages, rest dropped.
|
||||
count := 0
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
count++
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
if count != 256 {
|
||||
t.Fatalf("expected 256 buffered messages, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_Subscribe_ConcurrentWriteAndRead(t *testing.T) {
|
||||
lw := newLogWriterWithSize(64 * 1024)
|
||||
ch, unsub := lw.Subscribe()
|
||||
defer unsub()
|
||||
|
||||
const numWrites = 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < numWrites; i++ {
|
||||
lw.Write([]byte("concurrent write\n"))
|
||||
}
|
||||
}()
|
||||
|
||||
received := 0
|
||||
timeout := time.After(5 * time.Second)
|
||||
for received < numWrites {
|
||||
select {
|
||||
case <-ch:
|
||||
received++
|
||||
case <-timeout:
|
||||
t.Fatalf("timed out after receiving %d/%d messages", received, numWrites)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// tailFileLastLines tests
|
||||
// =============================================================================
|
||||
|
||||
func writeTempFile(t *testing.T, content string) *os.File {
|
||||
t.Helper()
|
||||
f, err := os.CreateTemp(t.TempDir(), "tail-test-*")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := f.WriteString(content); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_Empty(t *testing.T) {
|
||||
f := writeTempFile(t, "")
|
||||
defer f.Close()
|
||||
if got := tailFileLastLines(f, 10); got != nil {
|
||||
t.Fatalf("expected nil for empty file, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_FewerThanN(t *testing.T) {
|
||||
f := writeTempFile(t, "line1\nline2\n")
|
||||
defer f.Close()
|
||||
got := string(tailFileLastLines(f, 10))
|
||||
want := "line1\nline2\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_ExactN(t *testing.T) {
|
||||
f := writeTempFile(t, "a\nb\nc\n")
|
||||
defer f.Close()
|
||||
got := string(tailFileLastLines(f, 3))
|
||||
want := "a\nb\nc\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_MoreThanN(t *testing.T) {
|
||||
f := writeTempFile(t, "line1\nline2\nline3\nline4\nline5\n")
|
||||
defer f.Close()
|
||||
got := string(tailFileLastLines(f, 2))
|
||||
want := "line4\nline5\n"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_NoTrailingNewline(t *testing.T) {
|
||||
f := writeTempFile(t, "line1\nline2\nline3")
|
||||
defer f.Close()
|
||||
// Without trailing newline, partial last line comes with the previous line.
|
||||
got := string(tailFileLastLines(f, 1))
|
||||
want := "line2\nline3"
|
||||
if got != want {
|
||||
t.Fatalf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_LargerThanChunk(t *testing.T) {
|
||||
// Build content larger than the 4096 chunk size to exercise multi-chunk reads.
|
||||
var sb strings.Builder
|
||||
for i := 0; i < 200; i++ {
|
||||
sb.WriteString(strings.Repeat("x", 50))
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
f := writeTempFile(t, sb.String())
|
||||
defer f.Close()
|
||||
got := string(tailFileLastLines(f, 3))
|
||||
lines := strings.Split(strings.TrimRight(got, "\n"), "\n")
|
||||
if len(lines) != 3 {
|
||||
t.Fatalf("expected 3 lines, got %d: %q", len(lines), got)
|
||||
}
|
||||
expectedLine := strings.Repeat("x", 50)
|
||||
for _, line := range lines {
|
||||
if line != expectedLine {
|
||||
t.Fatalf("unexpected line content: %q", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test_tailFileLastLines_SeeksToEnd(t *testing.T) {
|
||||
f := writeTempFile(t, "line1\nline2\nline3\n")
|
||||
defer f.Close()
|
||||
tailFileLastLines(f, 1)
|
||||
|
||||
// After tailFileLastLines, file position should be at the end.
|
||||
pos, err := f.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if pos != stat.Size() {
|
||||
t.Fatalf("expected file position at end (%d), got %d", stat.Size(), pos)
|
||||
}
|
||||
}
|
||||
270
cmd/cli/log_writer.go
Normal file
270
cmd/cli/log_writer.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
logWriterSize = 1024 * 1024 * 5 // 5 MB
|
||||
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
|
||||
logWriterInitialSize = 32 * 1024 // 32 KB
|
||||
logWriterSentInterval = time.Minute
|
||||
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
|
||||
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
|
||||
)
|
||||
|
||||
type logViewResponse struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type logSentResponse struct {
|
||||
Size int64 `json:"size"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
type logReader struct {
|
||||
r io.ReadCloser
|
||||
size int64
|
||||
}
|
||||
|
||||
// logSubscriber represents a subscriber to live log output.
|
||||
type logSubscriber struct {
|
||||
ch chan []byte
|
||||
}
|
||||
|
||||
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
|
||||
type logWriter struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
size int
|
||||
subscribers []*logSubscriber
|
||||
}
|
||||
|
||||
// newLogWriter creates an internal log writer.
|
||||
func newLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSize)
|
||||
}
|
||||
|
||||
// newSmallLogWriter creates an internal log writer with small buffer size.
|
||||
func newSmallLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSmallSize)
|
||||
}
|
||||
|
||||
// newLogWriterWithSize creates an internal log writer with a given buffer size.
|
||||
func newLogWriterWithSize(size int) *logWriter {
|
||||
lw := &logWriter{size: size}
|
||||
return lw
|
||||
}
|
||||
|
||||
// Subscribe returns a channel that receives new log data as it's written,
|
||||
// and an unsubscribe function to clean up when done.
|
||||
func (lw *logWriter) Subscribe() (<-chan []byte, func()) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
sub := &logSubscriber{ch: make(chan []byte, 256)}
|
||||
lw.subscribers = append(lw.subscribers, sub)
|
||||
unsub := func() {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
for i, s := range lw.subscribers {
|
||||
if s == sub {
|
||||
lw.subscribers = append(lw.subscribers[:i], lw.subscribers[i+1:]...)
|
||||
close(sub.ch)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return sub.ch, unsub
|
||||
}
|
||||
|
||||
// tailLastLines returns the last n lines from the current buffer.
|
||||
func (lw *logWriter) tailLastLines(n int) []byte {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
data := lw.buf.Bytes()
|
||||
if n <= 0 || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Find the last n newlines from the end.
|
||||
count := 0
|
||||
pos := len(data)
|
||||
for pos > 0 {
|
||||
pos--
|
||||
if data[pos] == '\n' {
|
||||
count++
|
||||
if count == n+1 {
|
||||
pos++ // move past this newline
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
result := make([]byte, len(data)-pos)
|
||||
copy(result, data[pos:])
|
||||
return result
|
||||
}
|
||||
|
||||
func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
|
||||
// Fan-out to subscribers (non-blocking).
|
||||
if len(lw.subscribers) > 0 {
|
||||
cp := make([]byte, len(p))
|
||||
copy(cp, p)
|
||||
for _, sub := range lw.subscribers {
|
||||
select {
|
||||
case sub.ch <- cp:
|
||||
default:
|
||||
// Drop if subscriber is slow to avoid blocking the logger.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If writing p causes overflows, discard old data.
|
||||
if lw.buf.Len()+len(p) > lw.size {
|
||||
buf := lw.buf.Bytes()
|
||||
haveEndMarker := false
|
||||
// If there's init end marker already, preserve the data til the marker.
|
||||
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
|
||||
buf = buf[:idx+len(logWriterInitEndMarker)]
|
||||
haveEndMarker = true
|
||||
} else {
|
||||
// Otherwise, preserve the initial size data.
|
||||
buf = buf[:logWriterInitialSize]
|
||||
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
|
||||
buf = buf[:idx]
|
||||
}
|
||||
}
|
||||
lw.buf.Reset()
|
||||
lw.buf.Write(buf)
|
||||
if !haveEndMarker {
|
||||
lw.buf.WriteString(logWriterInitEndMarker) // indicate that the log was truncated.
|
||||
}
|
||||
}
|
||||
// If p is bigger than buffer size, truncate p by half until its size is smaller.
|
||||
for len(p)+lw.buf.Len() > lw.size {
|
||||
p = p[len(p)/2:]
|
||||
}
|
||||
return lw.buf.Write(p)
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func (p *prog) initLogging(backup bool) {
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
logWriters := initLoggingWithBackup(backup)
|
||||
|
||||
// Initializing internal logging after global logging.
|
||||
p.initInternalLogging(logWriters)
|
||||
}
|
||||
|
||||
// initInternalLogging performs internal logging if there's no log enabled.
|
||||
func (p *prog) initInternalLogging(writers []io.Writer) {
|
||||
if !p.needInternalLogging() {
|
||||
return
|
||||
}
|
||||
p.initInternalLogWriterOnce.Do(func() {
|
||||
mainLog.Load().Notice().Msg("internal logging enabled")
|
||||
p.internalLogWriter = newLogWriter()
|
||||
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
|
||||
p.internalWarnLogWriter = newSmallLogWriter()
|
||||
})
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
// If ctrld was run without explicit verbose level,
|
||||
// run the internal logging at debug level, so we could
|
||||
// have enough information for troubleshooting.
|
||||
if verbose == 0 {
|
||||
for i := range writers {
|
||||
w := &zerolog.FilteredLevelWriter{
|
||||
Writer: zerolog.LevelWriterAdapter{Writer: writers[i]},
|
||||
Level: zerolog.NoticeLevel,
|
||||
}
|
||||
writers[i] = w
|
||||
}
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
}
|
||||
writers = append(writers, lw)
|
||||
writers = append(writers, &zerolog.FilteredLevelWriter{
|
||||
Writer: zerolog.LevelWriterAdapter{Writer: wlw},
|
||||
Level: zerolog.WarnLevel,
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// needInternalLogging reports whether prog needs to run internal logging.
|
||||
func (p *prog) needInternalLogging() bool {
|
||||
// Do not run in non-cd mode.
|
||||
if cdUID == "" {
|
||||
return false
|
||||
}
|
||||
// Do not run if there's already log file.
|
||||
if p.cfg.Service.LogPath != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *prog) logReader() (*logReader, error) {
|
||||
if p.needInternalLogging() {
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
if lw == nil {
|
||||
return nil, errors.New("nil internal log writer")
|
||||
}
|
||||
if wlw == nil {
|
||||
return nil, errors.New("nil internal warn log writer")
|
||||
}
|
||||
// Normal log content.
|
||||
lw.mu.Lock()
|
||||
lwReader := bytes.NewReader(lw.buf.Bytes())
|
||||
lwSize := lw.buf.Len()
|
||||
lw.mu.Unlock()
|
||||
// Warn log content.
|
||||
wlw.mu.Lock()
|
||||
wlwReader := bytes.NewReader(wlw.buf.Bytes())
|
||||
wlwSize := wlw.buf.Len()
|
||||
wlw.mu.Unlock()
|
||||
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader)
|
||||
lr := &logReader{r: io.NopCloser(reader)}
|
||||
lr.size = int64(lwSize + wlwSize)
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("internal log is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
if p.cfg.Service.LogPath == "" {
|
||||
return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil
|
||||
}
|
||||
f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lr := &logReader{r: f}
|
||||
if st, err := f.Stat(); err == nil {
|
||||
lr.size = st.Size()
|
||||
} else {
|
||||
return nil, fmt.Errorf("f.Stat: %w", err)
|
||||
}
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("log file is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
85
cmd/cli/log_writer_test.go
Normal file
85
cmd/cli/log_writer_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_logWriter_Write(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
data := strings.Repeat("A", size)
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
newData := "B"
|
||||
halfData := strings.Repeat("A", len(data)/2) + logWriterInitEndMarker
|
||||
lw.Write([]byte(newData))
|
||||
if lw.buf.String() != halfData+newData {
|
||||
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
|
||||
}
|
||||
|
||||
bigData := strings.Repeat("B", 256*1024)
|
||||
expected := halfData + strings.Repeat("B", 16*1024)
|
||||
lw.Write([]byte(bigData))
|
||||
if lw.buf.String() != expected {
|
||||
t.Fatalf("unexpected big buf content: %v", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_ConcurrentWrite(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
n := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
lw.Write([]byte(strings.Repeat("A", i)))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if lw.buf.Len() > lw.size {
|
||||
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_MarkerInitEnd(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
|
||||
paddingSize := 10
|
||||
// Writing half of the size, minus len(end marker) and padding size.
|
||||
dataSize := size/2 - len(logWriterInitEndMarker) - paddingSize
|
||||
data := strings.Repeat("A", dataSize)
|
||||
// Inserting newline for making partial init data
|
||||
data += "\n"
|
||||
// Filling left over buffer to make the log full.
|
||||
// The data length: len(end marker) + padding size - 1 (for newline above) + size/2
|
||||
data += strings.Repeat("A", len(logWriterInitEndMarker)+paddingSize-1+(size/2))
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
lw.Write([]byte("B"))
|
||||
lw.Write([]byte(strings.Repeat("B", 256*1024)))
|
||||
firstIdx := strings.Index(lw.buf.String(), logWriterInitEndMarker)
|
||||
lastIdx := strings.LastIndex(lw.buf.String(), logWriterInitEndMarker)
|
||||
// Check if init end marker present.
|
||||
if firstIdx == -1 || lastIdx == -1 {
|
||||
t.Fatalf("missing init end marker: %s", lw.buf.String())
|
||||
}
|
||||
// Check if init end marker appears only once.
|
||||
if firstIdx != lastIdx {
|
||||
t.Fatalf("log init end marker appears more than once: %s", lw.buf.String())
|
||||
}
|
||||
// Ensure that we have the correct init log data.
|
||||
if !strings.Contains(lw.buf.String(), strings.Repeat("A", dataSize)+logWriterInitEndMarker) {
|
||||
t.Fatalf("unexpected log content: %s", lw.buf.String())
|
||||
}
|
||||
}
|
||||
145
cmd/cli/loop.go
Normal file
145
cmd/cli/loop.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
loopTestDomain = ".test"
|
||||
loopTestQtype = dns.TypeTXT
|
||||
)
|
||||
|
||||
// newLoopGuard returns new loopGuard.
|
||||
func newLoopGuard() *loopGuard {
|
||||
return &loopGuard{inflight: make(map[string]struct{})}
|
||||
}
|
||||
|
||||
// loopGuard guards against DNS loop, ensuring only one query
|
||||
// for a given domain is processed at a time.
|
||||
type loopGuard struct {
|
||||
mu sync.Mutex
|
||||
inflight map[string]struct{}
|
||||
}
|
||||
|
||||
// TryLock marks the domain as being processed.
|
||||
func (lg *loopGuard) TryLock(domain string) bool {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
if _, inflight := lg.inflight[domain]; !inflight {
|
||||
lg.inflight[domain] = struct{}{}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Unlock marks the domain as being done.
|
||||
func (lg *loopGuard) Unlock(domain string) {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
delete(lg.inflight, domain)
|
||||
}
|
||||
|
||||
// isLoop reports whether the given upstream config is detected as having DNS loop.
|
||||
func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool {
|
||||
p.loopMu.Lock()
|
||||
defer p.loopMu.Unlock()
|
||||
return p.loop[uc.UID()]
|
||||
}
|
||||
|
||||
// detectLoop checks if the given DNS message is initialized sent by ctrld.
|
||||
// If yes, marking the corresponding upstream as loop, prevent infinite DNS
|
||||
// forwarding loop.
|
||||
//
|
||||
// See p.checkDnsLoop for more details how it works.
|
||||
func (p *prog) detectLoop(msg *dns.Msg) {
|
||||
if len(msg.Question) != 1 {
|
||||
return
|
||||
}
|
||||
q := msg.Question[0]
|
||||
if q.Qtype != loopTestQtype {
|
||||
return
|
||||
}
|
||||
unFQDNname := strings.TrimSuffix(q.Name, ".")
|
||||
uid := strings.TrimSuffix(unFQDNname, loopTestDomain)
|
||||
p.loopMu.Lock()
|
||||
if _, loop := p.loop[uid]; loop {
|
||||
p.loop[uid] = loop
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
}
|
||||
|
||||
// checkDnsLoop sends a message to check if there's any DNS forwarding loop
|
||||
// with all the upstreams. The way it works based on dnsmasq --dns-loop-detect.
|
||||
//
|
||||
// - Generating a TXT test query and sending it to all upstream.
|
||||
// - The test query is formed by upstream UID and test domain: <uid>.test
|
||||
// - If the test query returns to ctrld, mark the corresponding upstream as loop (see p.detectLoop).
|
||||
//
|
||||
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
|
||||
func (p *prog) checkDnsLoop() {
|
||||
mainLog.Load().Debug().Msg("start checking DNS loop")
|
||||
upstream := make(map[string]*ctrld.UpstreamConfig)
|
||||
p.loopMu.Lock()
|
||||
for n, uc := range p.cfg.Upstream {
|
||||
if p.um.isDown("upstream." + n) {
|
||||
continue
|
||||
}
|
||||
// Do not send test query to external upstream.
|
||||
if !canBeLocalUpstream(uc.Domain) {
|
||||
mainLog.Load().Debug().Msgf("skipping external: upstream.%s", n)
|
||||
continue
|
||||
}
|
||||
uid := uc.UID()
|
||||
p.loop[uid] = false
|
||||
upstream[uid] = uc
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
|
||||
for uid := range p.loop {
|
||||
msg := loopTestMsg(uid)
|
||||
uc := upstream[uid]
|
||||
// Skipping upstream which is being marked as down.
|
||||
if uc == nil {
|
||||
continue
|
||||
}
|
||||
resolver, err := ctrld.NewResolver(uc)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
continue
|
||||
}
|
||||
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msg("end checking DNS loop")
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
func (p *prog) checkDnsLoopTicker(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Minute)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
p.checkDnsLoop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loopTestMsg generates DNS message for checking loop.
|
||||
func loopTestMsg(uid string) *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)
|
||||
return msg
|
||||
}
|
||||
42
cmd/cli/loop_test.go
Normal file
42
cmd/cli/loop_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_loopGuard(t *testing.T) {
|
||||
lg := newLoopGuard()
|
||||
key := "foo"
|
||||
|
||||
var i atomic.Int64
|
||||
var started atomic.Int64
|
||||
n := 1000
|
||||
do := func() {
|
||||
locked := lg.TryLock(key)
|
||||
defer lg.Unlock(key)
|
||||
started.Add(1)
|
||||
for started.Load() < 2 {
|
||||
// Wait until at least 2 goroutines started, otherwise, on system with heavy load,
|
||||
// or having only 1 CPU, all goroutines can be scheduled to run consequently.
|
||||
}
|
||||
if locked {
|
||||
i.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
do()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if i.Load() == int64(n) {
|
||||
t.Fatalf("i must not be increased %d times", n)
|
||||
}
|
||||
}
|
||||
228
cmd/cli/main.go
Normal file
228
cmd/cli/main.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
var (
|
||||
configPath string
|
||||
configBase64 string
|
||||
daemon bool
|
||||
listenAddress string
|
||||
primaryUpstream string
|
||||
secondaryUpstream string
|
||||
domains []string
|
||||
logPath string
|
||||
homedir string
|
||||
cacheSize int
|
||||
cfg ctrld.Config
|
||||
verbose int
|
||||
silent bool
|
||||
cdUID string
|
||||
cdOrg string
|
||||
customHostname string
|
||||
cdDev bool
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
nextdns string
|
||||
cdUpstreamProto string
|
||||
deactivationPin int64
|
||||
skipSelfChecks bool
|
||||
cleanup bool
|
||||
startOnly bool
|
||||
rfc1918 bool
|
||||
interceptMode string // "", "dns", or "hard" — set via --intercept-mode flag or config
|
||||
dnsIntercept bool // derived: interceptMode == "dns" || interceptMode == "hard"
|
||||
hardIntercept bool // derived: interceptMode == "hard"
|
||||
|
||||
mainLog atomic.Pointer[zerolog.Logger]
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
noConfigStart bool
|
||||
)
|
||||
|
||||
const (
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
customHostnameFlagName = "custom-hostname"
|
||||
nextdnsFlagName = "nextdns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
l := zerolog.New(io.Discard)
|
||||
mainLog.Store(&l)
|
||||
}
|
||||
|
||||
func Main() {
|
||||
// Fast path for pf interception probe subprocess. This runs before cobra
|
||||
// initialization to minimize startup time. The parent process spawns us with
|
||||
// "pf-probe-send <host> <hex-dns-packet>" and a non-_ctrld GID so pf
|
||||
// intercepts the DNS query. If pf rdr is working, the query reaches ctrld's
|
||||
// listener; if not, it goes to the real DNS server and ctrld detects the miss.
|
||||
if len(os.Args) >= 4 && os.Args[1] == "pf-probe-send" {
|
||||
pfProbeSend(os.Args[2], os.Args[3])
|
||||
return
|
||||
}
|
||||
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
initCLI()
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeLogFilePath(logFilePath string) string {
|
||||
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
||||
return logFilePath
|
||||
}
|
||||
if homedir != "" {
|
||||
return filepath.Join(homedir, logFilePath)
|
||||
}
|
||||
dir, _ := userHomeDir()
|
||||
if dir == "" {
|
||||
return logFilePath
|
||||
}
|
||||
return filepath.Join(dir, logFilePath)
|
||||
}
|
||||
|
||||
// initConsoleLogging initializes console logging, then storing to mainLog.
|
||||
func initConsoleLogging() {
|
||||
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.StampMilli
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter)
|
||||
l := mainLog.Load().Output(multi).With().Timestamp().Logger()
|
||||
mainLog.Store(&l)
|
||||
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
case verbose == 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
case verbose > 1:
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
|
||||
// to be used for all interactive commands.
|
||||
//
|
||||
// Current log file config will also be ignored.
|
||||
func initInteractiveLogging() {
|
||||
old := cfg.Service.LogPath
|
||||
cfg.Service.LogPath = ""
|
||||
zerolog.TimeFieldFormat = time.RFC3339 + ".000"
|
||||
initLoggingWithBackup(false)
|
||||
cfg.Service.LogPath = old
|
||||
l := zerolog.New(io.Discard)
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
}
|
||||
|
||||
// initLoggingWithBackup initializes log setup base on current config.
|
||||
// If doBackup is true, backup old log file with ".1" suffix.
|
||||
//
|
||||
// This is only used in runCmd for special handling in case of logging config
|
||||
// change in cd mode. Without special reason, the caller should use initLogging
|
||||
// wrapper instead of calling this function directly.
|
||||
func initLoggingWithBackup(doBackup bool) []io.Writer {
|
||||
var writers []io.Writer
|
||||
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
|
||||
// Create parent directory if necessary.
|
||||
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
|
||||
mainLog.Load().Error().Msgf("failed to create log path: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Default open log file in append mode.
|
||||
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
|
||||
if doBackup {
|
||||
// Backup old log file with .1 suffix.
|
||||
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
|
||||
mainLog.Load().Error().Msgf("could not backup old log file: %v", err)
|
||||
} else {
|
||||
// Backup was created, set flags for truncating old log file.
|
||||
flags = os.O_CREATE | os.O_RDWR
|
||||
}
|
||||
}
|
||||
logFile, err := openLogFile(logFilePath, flags)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("failed to create log file: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
writers = append(writers, logFile)
|
||||
}
|
||||
writers = append(writers, consoleWriter)
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
l := mainLog.Load().Output(multi).With().Logger()
|
||||
mainLog.Store(&l)
|
||||
// TODO: find a better way.
|
||||
ctrld.ProxyLogger.Store(&l)
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
logLevel := cfg.Service.LogLevel
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
return writers
|
||||
case verbose == 1:
|
||||
logLevel = "info"
|
||||
case verbose > 1:
|
||||
logLevel = "debug"
|
||||
}
|
||||
if logLevel == "" {
|
||||
return writers
|
||||
}
|
||||
level, err := zerolog.ParseLevel(logLevel)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set log level")
|
||||
return writers
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
return writers
|
||||
}
|
||||
|
||||
func initCache() {
|
||||
if !cfg.Service.CacheEnable {
|
||||
return
|
||||
}
|
||||
if cfg.Service.CacheSize == 0 {
|
||||
cfg.Service.CacheSize = 4096
|
||||
}
|
||||
}
|
||||
|
||||
// pfProbeSend is a minimal subprocess that sends a pre-built DNS query packet
|
||||
// to the specified host on port 53. It's invoked by probePFIntercept() with a
|
||||
// non-_ctrld GID so pf interception applies to the query.
|
||||
//
|
||||
// Usage: ctrld pf-probe-send <host> <hex-encoded-dns-packet>
|
||||
func pfProbeSend(host, hexPacket string) {
|
||||
packet, err := hex.DecodeString(hexPacket)
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
conn, err := net.DialTimeout("udp", net.JoinHostPort(host, "53"), time.Second)
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
defer conn.Close()
|
||||
conn.SetDeadline(time.Now().Add(time.Second))
|
||||
_, _ = conn.Write(packet)
|
||||
// Read response (don't care about result, just need the send to happen)
|
||||
buf := make([]byte, 512)
|
||||
_, _ = conn.Read(buf)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
var logOutput strings.Builder
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
mainLog = zerolog.New(&logOutput)
|
||||
l := zerolog.New(&logOutput)
|
||||
mainLog.Store(&l)
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
150
cmd/cli/metrics.go
Normal file
150
cmd/cli/metrics.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/prometheus/prom2json"
|
||||
)
|
||||
|
||||
// metricsServer represents a server to expose Prometheus metrics via HTTP.
|
||||
type metricsServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
reg *prometheus.Registry
|
||||
addr string
|
||||
started bool
|
||||
}
|
||||
|
||||
// newMetricsServer returns new metrics server.
|
||||
func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, error) {
|
||||
mux := http.NewServeMux()
|
||||
ms := &metricsServer{
|
||||
server: &http.Server{Handler: mux},
|
||||
mux: mux,
|
||||
reg: reg,
|
||||
}
|
||||
ms.addr = addr
|
||||
ms.registerMetricsServerHandler()
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// register adds handlers for given pattern.
|
||||
func (ms *metricsServer) register(pattern string, handler http.Handler) {
|
||||
ms.mux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
// registerMetricsServerHandler adds handlers for metrics server.
|
||||
func (ms *metricsServer) registerMetricsServerHandler() {
|
||||
ms.register("/metrics", promhttp.HandlerFor(
|
||||
ms.reg,
|
||||
promhttp.HandlerOpts{
|
||||
EnableOpenMetrics: true,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
))
|
||||
ms.register("/metrics/json", jsonResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
g := prometheus.ToTransactionalGatherer(ms.reg)
|
||||
mfs, done, err := g.Gather()
|
||||
defer done()
|
||||
if err != nil {
|
||||
msg := "could not gather metrics"
|
||||
mainLog.Load().Warn().Err(err).Msg(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
result := make([]*prom2json.Family, 0, len(mfs))
|
||||
for _, mf := range mfs {
|
||||
result = append(result, prom2json.NewFamily(mf))
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(result); err != nil {
|
||||
msg := "could not marshal metrics result"
|
||||
mainLog.Load().Warn().Err(err).Msg(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
// start runs the metricsServer.
|
||||
func (ms *metricsServer) start() error {
|
||||
listener, err := net.Listen("tcp", ms.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go ms.server.Serve(listener)
|
||||
ms.started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop shutdowns the metricsServer within 2 seconds timeout.
|
||||
func (ms *metricsServer) stop() error {
|
||||
if !ms.started {
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
|
||||
defer cancel()
|
||||
return ms.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// runMetricsServer initializes metrics stats and runs the metrics server if enabled.
|
||||
func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
if !p.metricsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset all stats.
|
||||
statsVersion.Reset()
|
||||
statsQueriesCount.Reset()
|
||||
statsClientQueriesCount.Reset()
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
// Register queries count stats if enabled.
|
||||
if p.metricsQueryStats.Load() {
|
||||
reg.MustRegister(statsQueriesCount)
|
||||
reg.MustRegister(statsClientQueriesCount)
|
||||
}
|
||||
|
||||
addr := p.cfg.Service.MetricsListener
|
||||
ms, err := newMetricsServer(addr, reg)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create new metrics server")
|
||||
return
|
||||
}
|
||||
// Only start listener address if defined.
|
||||
if addr != "" {
|
||||
// Go runtime stats.
|
||||
reg.MustRegister(collectors.NewBuildInfoCollector())
|
||||
reg.MustRegister(collectors.NewGoCollector(
|
||||
collectors.WithGoCollectorRuntimeMetrics(collectors.MetricsAll),
|
||||
))
|
||||
// ctrld stats.
|
||||
reg.MustRegister(statsVersion)
|
||||
statsVersion.WithLabelValues(commit, runtime.Version(), curVersion()).Inc()
|
||||
reg.MustRegister(statsTimeStart)
|
||||
statsTimeStart.Set(float64(time.Now().Unix()))
|
||||
mainLog.Load().Debug().Msgf("starting metrics server on: %s", addr)
|
||||
if err := ms.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not start metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
case <-ctx.Done():
|
||||
case <-reloadCh:
|
||||
}
|
||||
|
||||
if err := ms.stop(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not stop metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
76
cmd/cli/net_darwin.go
Normal file
76
cmd/cli/net_darwin.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
patched := false
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
patched = true
|
||||
iface.Name = name
|
||||
}
|
||||
return patched, nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
scanner := bufio.NewScanner(r)
|
||||
prevLine := ""
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, "*") {
|
||||
// Network services is disabled.
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(line, "Device: "+ifaceName) {
|
||||
prevLine = line
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(prevLine, " ", 2)
|
||||
if len(parts) == 2 {
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// validInterfacesMap returns a set of all valid hardware ports.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
b, err := exec.Command("networksetup", "-listallhardwareports").Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return parseListAllHardwarePorts(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
// parseListAllHardwarePorts parses output of "networksetup -listallhardwareports"
|
||||
// and returns map presents all hardware ports.
|
||||
func parseListAllHardwarePorts(r io.Reader) map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
scanner := bufio.NewScanner(r)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
after, ok := strings.CutPrefix(line, "Device: ")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
m[after] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
52
cmd/cli/net_linux.go
Normal file
52
cmd/cli/net_linux.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// Only non-virtual interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// validInterfacesMap returns a set containing non virtual interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
vis := virtualInterfaces()
|
||||
netmon.ForeachInterface(func(i netmon.Interface, prefixes []netip.Prefix) {
|
||||
if _, existed := vis[i.Name]; existed {
|
||||
return
|
||||
}
|
||||
m[i.Name] = struct{}{}
|
||||
})
|
||||
// Fallback to default route interface if found nothing.
|
||||
if len(m) == 0 {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return m
|
||||
}
|
||||
m[defaultRoute.InterfaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// virtualInterfaces returns a map of virtual interfaces on current machine.
|
||||
func virtualInterfaces() map[string]struct{} {
|
||||
s := make(map[string]struct{})
|
||||
entries, _ := os.ReadDir("/sys/devices/virtual/net")
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
s[strings.TrimSpace(entry.Name())] = struct{}{}
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
22
cmd/cli/net_others.go
Normal file
22
cmd/cli/net_others.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build !darwin && !windows && !linux
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
|
||||
|
||||
// validInterfacesMap returns a set containing only default route interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
defaultRoute, err := netmon.DefaultRoute()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return map[string]struct{}{defaultRoute.InterfaceName: {}}
|
||||
}
|
||||
93
cmd/cli/net_windows.go
Normal file
93
cmd/cli/net_windows.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"github.com/microsoft/wmi/pkg/hardware/network/netadapter"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// On Windows, only physical interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
// validInterfacesMap returns a set of all physical interfaces.
|
||||
func validInterfacesMap() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
for _, ifaceName := range validInterfaces() {
|
||||
m[ifaceName] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// validInterfaces returns a list of all physical interfaces.
|
||||
func validInterfaces() []string {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("MSFT_NetAdapter")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.StadardCimV2), q)
|
||||
if instances != nil {
|
||||
defer instances.Close()
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get wmi network adapter")
|
||||
return nil
|
||||
}
|
||||
var adapters []string
|
||||
for _, i := range instances {
|
||||
adapter, err := netadapter.NewNetworkAdapter(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get network adapter")
|
||||
continue
|
||||
}
|
||||
|
||||
name, err := adapter.GetPropertyName()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("failed to get interface name")
|
||||
continue
|
||||
}
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/previous-versions/windows/desktop/legacy/hh968170(v=vs.85)
|
||||
//
|
||||
// "Indicates if a connector is present on the network adapter. This value is set to TRUE
|
||||
// if this is a physical adapter or FALSE if this is not a physical adapter."
|
||||
physical, err := adapter.GetPropertyConnectorPresent()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter connector present property")
|
||||
continue
|
||||
}
|
||||
if !physical {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-physical adapter")
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if it's a hardware interface. Checking only for connector present is not enough
|
||||
// because some interfaces are not physical but have a connector.
|
||||
hardware, err := adapter.GetPropertyHardwareInterface()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("failed to get network adapter hardware interface property")
|
||||
continue
|
||||
}
|
||||
if !hardware {
|
||||
mainLog.Load().Debug().Str("method", "validInterfaces").Str("interface", name).Msg("skipping non-hardware interface")
|
||||
continue
|
||||
}
|
||||
|
||||
adapters = append(adapters, name)
|
||||
}
|
||||
return adapters
|
||||
}
|
||||
42
cmd/cli/net_windows_test.go
Normal file
42
cmd/cli/net_windows_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_validInterfaces(t *testing.T) {
|
||||
verbose = 3
|
||||
initConsoleLogging()
|
||||
start := time.Now()
|
||||
ifaces := validInterfaces()
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
ifacesPowershell := validInterfacesPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(ifaces)
|
||||
slices.Sort(ifacesPowershell)
|
||||
if !slices.Equal(ifaces, ifacesPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces)
|
||||
}
|
||||
}
|
||||
|
||||
func validInterfacesPowershell() []string {
|
||||
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var res []string
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
ifaceName := strings.TrimSpace(scanner.Text())
|
||||
res = append(res, ifaceName)
|
||||
}
|
||||
return res
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
package main
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
@@ -16,45 +17,56 @@ const (
|
||||
dns=none
|
||||
systemd-resolved=false
|
||||
`
|
||||
nmSystemdUnitName = "NetworkManager.service"
|
||||
systemdEnabledState = "enabled"
|
||||
nmSystemdUnitName = "NetworkManager.service"
|
||||
)
|
||||
|
||||
var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename)
|
||||
|
||||
// hasNetworkManager reports whether NetworkManager executable found.
|
||||
func hasNetworkManager() bool {
|
||||
exe, _ := exec.LookPath("NetworkManager")
|
||||
return exe != ""
|
||||
}
|
||||
|
||||
func setupNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent {
|
||||
mainLog.Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
mainLog.Load().Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
return nil
|
||||
}
|
||||
err := os.WriteFile(networkManagerCtrldConfFile, []byte(nmCtrldConfContent), os.FileMode(0644))
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Debug().Msg("NetworkManager is not available")
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Debug().Err(err).Msg("could not write NetworkManager ctrld config file")
|
||||
mainLog.Load().Debug().Err(err).Msg("could not write NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
reloadNetworkManager()
|
||||
mainLog.Debug().Msg("setup NetworkManager done")
|
||||
mainLog.Load().Debug().Msg("setup NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func restoreNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
err := os.Remove(networkManagerCtrldConfFile)
|
||||
if os.IsNotExist(err) {
|
||||
mainLog.Debug().Msg("NetworkManager is not available")
|
||||
mainLog.Load().Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Debug().Err(err).Msg("could not remove NetworkManager ctrld config file")
|
||||
mainLog.Load().Debug().Err(err).Msg("could not remove NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
reloadNetworkManager()
|
||||
mainLog.Debug().Msg("restore NetworkManager done")
|
||||
mainLog.Load().Debug().Msg("restore NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -63,14 +75,15 @@ func reloadNetworkManager() {
|
||||
defer cancel()
|
||||
conn, err := dbus.NewSystemConnectionContext(ctx)
|
||||
if err != nil {
|
||||
mainLog.Error().Err(err).Msg("could not create new system connection")
|
||||
mainLog.Load().Error().Err(err).Msg("could not create new system connection")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
waitCh := make(chan string)
|
||||
if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil {
|
||||
mainLog.Debug().Err(err).Msg("could not reload NetworkManager")
|
||||
mainLog.Load().Debug().Err(err).Msg("could not reload NetworkManager")
|
||||
return
|
||||
}
|
||||
<-waitCh
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !linux
|
||||
|
||||
package main
|
||||
package cli
|
||||
|
||||
func setupNetworkManager() error {
|
||||
reloadNetworkManager()
|
||||
31
cmd/cli/nextdns.go
Normal file
31
cmd/cli/nextdns.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const nextdnsURL = "https://dns.nextdns.io"
|
||||
|
||||
func generateNextDNSConfig(uid string) {
|
||||
if uid == "" {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msg("generating ctrld config for NextDNS resolver")
|
||||
cfg = ctrld.Config{
|
||||
Listener: map[string]*ctrld.ListenerConfig{
|
||||
"0": {
|
||||
IP: "0.0.0.0",
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Upstream: map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
Type: ctrld.ResolverTypeDOH3,
|
||||
Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, uid),
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
5
cmd/cli/nocgo.go
Normal file
5
cmd/cli/nocgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = false
|
||||
114
cmd/cli/os_darwin.go
Normal file
114
cmd/cli/os_darwin.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// allocate loopback ip
|
||||
// sudo ifconfig lo0 alias 127.0.0.2 up
|
||||
func allocateIP(ip string) error {
|
||||
cmd := exec.Command("ifconfig", "lo0", "alias", ip, "up")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("allocateIP failed")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deAllocateIP(ip string) error {
|
||||
cmd := exec.Command("ifconfig", "lo0", "-alias", ip)
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
if err := setDNS(iface, nameservers); err != nil {
|
||||
// TODO: investiate whether we can detect this without relying on error message.
|
||||
if strings.Contains(err.Error(), " is not a recognized network service") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// set the dns server for the provided network interface
|
||||
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
|
||||
// TODO(cuonglm): use system API
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
// Note that networksetup won't modify search domains settings,
|
||||
// This assignment is just a placeholder to silent linter.
|
||||
_ = searchDomains
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name}
|
||||
args = append(args, nameservers...)
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
if err := resetDNS(iface); err != nil {
|
||||
// TODO: investiate whether we can detect this without relying on error message.
|
||||
if strings.Contains(err.Error(), " is not a recognized network service") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(cuonglm): use system API
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name, "empty"}
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if ns := savedStaticNameservers(iface); len(ns) > 0 {
|
||||
err = setDNS(iface, ns)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
cmd := "networksetup"
|
||||
args := []string{"-getdnsservers", iface.Name}
|
||||
out, err := exec.Command(cmd, args...).Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
var ns []string
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if ip := net.ParseIP(line); ip != nil {
|
||||
ns = append(ns, ip.String())
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
103
cmd/cli/os_freebsd.go
Normal file
103
cmd/cli/os_freebsd.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// allocate loopback ip
|
||||
// sudo ifconfig lo0 127.0.0.53 alias
|
||||
func allocateIP(ip string) error {
|
||||
cmd := exec.Command("ifconfig", "lo0", ip, "alias")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("allocateIP failed")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deAllocateIP(ip string) error {
|
||||
cmd := exec.Command("ifconfig", "lo0", ip, "-alias")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
// set the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
ns := make([]netip.Addr, 0, len(nameservers))
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, netip.MustParseAddr(nameserver))
|
||||
}
|
||||
|
||||
osConfig := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
|
||||
}
|
||||
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to set DNS")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.Close(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return resolvconffile.NameServers()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
return currentDNS(iface), nil
|
||||
}
|
||||
314
cmd/cli/os_linux.go
Normal file
314
cmd/cli/os_linux.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/client6"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
|
||||
// allocate loopback ip
|
||||
// sudo ip a add 127.0.0.2/24 dev lo
|
||||
func allocateIP(ip string) error {
|
||||
cmd := exec.Command("ip", "a", "add", ip+"/24", "dev", "lo")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("allocateIP failed: %s", string(out))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func deAllocateIP(ip string) error {
|
||||
cmd := exec.Command("ip", "a", "del", ip+"/24", "dev", "lo")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxSetDNSAttempts = 5
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to create DNS OS configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
ns := make([]netip.Addr, 0, len(nameservers))
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, netip.MustParseAddr(nameserver))
|
||||
}
|
||||
|
||||
osConfig := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
// Filter the root domain, since it's not allowed by systemd.
|
||||
// See https://github.com/systemd/systemd/issues/9515
|
||||
filteredSds := slices.DeleteFunc(sds, func(s dnsname.FQDN) bool {
|
||||
return s == "" || s == "."
|
||||
})
|
||||
if len(filteredSds) != len(sds) {
|
||||
mainLog.Load().Debug().Msg(`Removed root domain "." from search domains list`)
|
||||
}
|
||||
osConfig.SearchDomains = filteredSds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list")
|
||||
}
|
||||
trySystemdResolve := false
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
goto systemdResolve
|
||||
}
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
systemdResolve:
|
||||
if trySystemdResolve {
|
||||
// Stop systemd-networkd and retry setting DNS.
|
||||
if out, err := exec.Command("systemctl", "stop", "systemd-networkd").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%s: %w", string(out), err)
|
||||
}
|
||||
args := []string{"--interface=" + iface.Name, "--set-domain=~"}
|
||||
for _, nameserver := range nameservers {
|
||||
args = append(args, "--set-dns="+nameserver)
|
||||
}
|
||||
for i := 0; i < maxSetDNSAttempts; i++ {
|
||||
if out, err := exec.Command("systemd-resolve", args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%s: %w", string(out), err)
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
func resetDNS(iface *net.Interface) (err error) {
|
||||
defer func() {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
// Start systemd-networkd if present.
|
||||
if exe, _ := exec.LookPath("/lib/systemd/systemd-networkd"); exe != "" {
|
||||
_ = exec.Command("systemctl", "start", "systemd-networkd").Run()
|
||||
}
|
||||
if r, oerr := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name); oerr == nil {
|
||||
_ = r.SetDNS(dns.OSConfig{})
|
||||
if err := r.Close(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to rollback DNS setting")
|
||||
return
|
||||
}
|
||||
err = nil
|
||||
}
|
||||
}()
|
||||
|
||||
var ns []string
|
||||
c, err := nclient4.New(iface.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nclient4.New: %w", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
lease, err := c.Request(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nclient4.Request: %w", err)
|
||||
}
|
||||
for _, nameserver := range lease.ACK.DNS() {
|
||||
if nameserver.Equal(net.IPv4zero) {
|
||||
continue
|
||||
}
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
|
||||
// TODO(cuonglm): handle DHCPv6 properly.
|
||||
mainLog.Load().Debug().Msg("checking for IPv6 availability")
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
c := client6.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
if err != nil && !errAddrInUse(err) {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not exchange DHCPv6")
|
||||
}
|
||||
for _, packet := range conversation {
|
||||
if packet.Type() == dhcpv6.MessageTypeReply {
|
||||
msg, err := packet.GetInnerMessage()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("could not get inner DHCPv6 message")
|
||||
return nil
|
||||
}
|
||||
nameservers := msg.Options.DNS()
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("IPv6 is not available")
|
||||
}
|
||||
|
||||
return ignoringEINTR(func() error {
|
||||
return setDNS(iface, ns)
|
||||
})
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
resolvconfFunc := func(_ string) []string { return resolvconffile.NameServers() }
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} {
|
||||
if ns := fn(iface.Name); len(ns) > 0 {
|
||||
return ns
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
return currentDNS(iface), nil
|
||||
}
|
||||
|
||||
func getDNSByResolvectl(iface string) []string {
|
||||
b, err := exec.Command("resolvectl", "dns", "-i", iface).Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Fields(strings.SplitN(string(b), "%", 2)[0])
|
||||
if len(parts) > 2 {
|
||||
return parts[3:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDNSBySystemdResolved(iface string) []string {
|
||||
b, err := exec.Command("systemd-resolve", "--status", iface).Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return getDNSBySystemdResolvedFromReader(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
func getDNSBySystemdResolvedFromReader(r io.Reader) []string {
|
||||
scanner := bufio.NewScanner(r)
|
||||
var ret []string
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if len(ret) > 0 {
|
||||
if net.ParseIP(line) != nil {
|
||||
ret = append(ret, line)
|
||||
}
|
||||
continue
|
||||
}
|
||||
after, found := strings.CutPrefix(line, "DNS Servers: ")
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
if net.ParseIP(after) != nil {
|
||||
ret = append(ret, after)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func getDNSByNmcli(iface string) []string {
|
||||
b, err := exec.Command("nmcli", "dev", "show", iface).Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
s := bufio.NewScanner(bytes.NewReader(b))
|
||||
var dns []string
|
||||
do := func(line string) {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) > 1 {
|
||||
dns = append(dns, strings.TrimSpace(parts[1]))
|
||||
}
|
||||
}
|
||||
for s.Scan() {
|
||||
line := s.Text()
|
||||
switch {
|
||||
case strings.HasPrefix(line, "IP4.DNS"):
|
||||
fallthrough
|
||||
case strings.HasPrefix(line, "IP6.DNS"):
|
||||
do(line)
|
||||
}
|
||||
}
|
||||
return dns
|
||||
}
|
||||
|
||||
func ignoringEINTR(fn func() error) error {
|
||||
for {
|
||||
err := fn()
|
||||
if err != syscall.EINTR {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isSubSet reports whether s2 contains all elements of s1.
|
||||
func isSubSet(s1, s2 []string) bool {
|
||||
ok := true
|
||||
for _, ns := range s1 {
|
||||
if slices.Contains(s2, ns) {
|
||||
continue
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
return ok
|
||||
}
|
||||
23
cmd/cli/os_linux_test.go
Normal file
23
cmd/cli/os_linux_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_getDNSBySystemdResolvedFromReader(t *testing.T) {
|
||||
r := strings.NewReader(`Link 2 (eth0)
|
||||
Current Scopes: DNS
|
||||
LLMNR setting: yes
|
||||
MulticastDNS setting: no
|
||||
DNSSEC setting: no
|
||||
DNSSEC supported: no
|
||||
DNS Servers: 8.8.8.8
|
||||
8.8.4.4`)
|
||||
want := []string{"8.8.8.8", "8.8.4.4"}
|
||||
ns := getDNSBySystemdResolvedFromReader(r)
|
||||
if !reflect.DeepEqual(ns, want) {
|
||||
t.Logf("unexpected result, want: %v, got: %v", want, ns)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !linux && !darwin && !freebsd
|
||||
|
||||
package main
|
||||
package cli
|
||||
|
||||
// TODO(cuonglm): implement.
|
||||
func allocateIP(ip string) error {
|
||||
332
cmd/cli/os_windows.go
Normal file
332
cmd/cli/os_windows.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
const (
|
||||
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
var (
|
||||
setDNSOnce sync.Once
|
||||
resetDNSOnce sync.Once
|
||||
)
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
// setDNS sets the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
if len(nameservers) == 0 {
|
||||
return errors.New("empty DNS nameservers")
|
||||
}
|
||||
setDNSOnce.Do(func() {
|
||||
// If there's a Dns server running, that means we are on AD with Dns feature enabled.
|
||||
// Configuring the Dns server to forward queries to ctrld instead.
|
||||
if hasLocalDnsServerRunning() {
|
||||
mainLog.Load().Debug().Msg("Local DNS server detected, configuring forwarders")
|
||||
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
mainLog.Load().Debug().Msgf("Using forwarders file: %s", file)
|
||||
|
||||
oldForwardersContent, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not read existing forwarders file")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Existing forwarders content: %s", string(oldForwardersContent))
|
||||
}
|
||||
|
||||
hasLocalIPv6Listener := needLocalIPv6Listener(interceptMode)
|
||||
mainLog.Load().Debug().Bool("has_ipv6_listener", hasLocalIPv6Listener).Msg("IPv6 listener status")
|
||||
|
||||
forwarders := slices.DeleteFunc(slices.Clone(nameservers), func(s string) bool {
|
||||
if !hasLocalIPv6Listener {
|
||||
return false
|
||||
}
|
||||
return s == "::1"
|
||||
})
|
||||
mainLog.Load().Debug().Strs("forwarders", forwarders).Msg("Filtered forwarders list")
|
||||
|
||||
if err := os.WriteFile(file, []byte(strings.Join(forwarders, ",")), 0600); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not save forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully wrote new forwarders file")
|
||||
}
|
||||
|
||||
oldForwarders := strings.Split(string(oldForwardersContent), ",")
|
||||
mainLog.Load().Debug().Strs("old_forwarders", oldForwarders).Msg("Previous forwarders")
|
||||
|
||||
if err := addDnsServerForwarders(forwarders, oldForwarders); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not set forwarders settings")
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("Successfully configured DNS server forwarders")
|
||||
}
|
||||
}
|
||||
})
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("setDNS: %w", err)
|
||||
}
|
||||
var (
|
||||
serversV4 []netip.Addr
|
||||
serversV6 []netip.Addr
|
||||
)
|
||||
for _, ns := range nameservers {
|
||||
if addr, err := netip.ParseAddr(ns); err == nil {
|
||||
if addr.Is4() {
|
||||
serversV4 = append(serversV4, addr)
|
||||
} else {
|
||||
serversV6 = append(serversV6, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function.
|
||||
// searchDomains is still implemented for Windows just in case Windows API changes in future versions.
|
||||
_ = searchDomains
|
||||
|
||||
if len(serversV4) == 0 && len(serversV6) == 0 {
|
||||
return errors.New("invalid DNS nameservers")
|
||||
}
|
||||
if len(serversV4) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv4: %w", err)
|
||||
}
|
||||
}
|
||||
if len(serversV6) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv6: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
// TODO(cuonglm): should we use system API?
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
resetDNSOnce.Do(func() {
|
||||
// See corresponding comment in setDNS.
|
||||
if hasLocalDnsServerRunning() {
|
||||
file := absHomeDir(windowsForwardersFilename)
|
||||
content, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not read forwarders settings")
|
||||
return
|
||||
}
|
||||
nameservers := strings.Split(string(content), ",")
|
||||
if err := removeDnsServerForwarders(nameservers); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("could not remove forwarders settings")
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("resetDNS: %w", err)
|
||||
}
|
||||
// Restoring DHCP settings.
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv4: %w", err)
|
||||
}
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if nss := savedStaticNameservers(iface); len(nss) > 0 {
|
||||
v4ns := make([]string, 0, 2)
|
||||
v6ns := make([]string, 0, 2)
|
||||
for _, ns := range nss {
|
||||
if ctrldnet.IsIPv6(ns) {
|
||||
v6ns = append(v6ns, ns)
|
||||
} else {
|
||||
v4ns = append(v4ns, ns)
|
||||
}
|
||||
}
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoreDNS: %w", err)
|
||||
}
|
||||
|
||||
if len(v4ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
|
||||
if err := setDNS(iface, v4ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv4 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(v6ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
|
||||
if err := setDNS(iface, v6ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("restoring IPv6 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get interface LUID")
|
||||
return nil
|
||||
}
|
||||
nameservers, err := luid.DNS()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to get interface DNS")
|
||||
return nil
|
||||
}
|
||||
ns := make([]string, 0, len(nameservers))
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys
|
||||
// like "NameServer" and "ProfileNameServer".
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err)
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fallback luid.GUID: %w", err)
|
||||
}
|
||||
|
||||
var ns []string
|
||||
keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat}
|
||||
for _, path := range keyPaths {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer k.Close()
|
||||
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
|
||||
value, _, err := k.GetStringValue(keyName)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
mainLog.Load().Debug().Err(err).Msgf("error reading %s registry key", keyName)
|
||||
continue
|
||||
}
|
||||
if len(value) > 0 {
|
||||
mainLog.Load().Debug().Msgf("found static DNS for interface %q: %s", iface.Name, value)
|
||||
parsed := parseDNSServers(value)
|
||||
for _, pns := range parsed {
|
||||
if !slices.Contains(ns, pns) {
|
||||
ns = append(ns, pns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Debug().Msgf("no static DNS values found for interface %q", iface.Name)
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// parseDNSServers splits a DNS server string that may be comma- or space-separated,
|
||||
// and trims any extraneous whitespace or null characters.
|
||||
func parseDNSServers(val string) []string {
|
||||
fields := strings.FieldsFunc(val, func(r rune) bool {
|
||||
return r == ' ' || r == ','
|
||||
})
|
||||
var servers []string
|
||||
for _, f := range fields {
|
||||
trimmed := strings.TrimSpace(f)
|
||||
if len(trimmed) > 0 {
|
||||
servers = append(servers, trimmed)
|
||||
}
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
// addDnsServerForwarders adds given nameservers to DNS server forwarders list,
|
||||
// and also removing old forwarders if provided.
|
||||
func addDnsServerForwarders(nameservers, old []string) error {
|
||||
newForwardersMap := make(map[string]struct{})
|
||||
newForwarders := make([]string, len(nameservers))
|
||||
for i := range nameservers {
|
||||
newForwardersMap[nameservers[i]] = struct{}{}
|
||||
newForwarders[i] = fmt.Sprintf("%q", nameservers[i])
|
||||
}
|
||||
oldForwarders := old[:0]
|
||||
for _, fwd := range old {
|
||||
if _, ok := newForwardersMap[fwd]; !ok {
|
||||
oldForwarders = append(oldForwarders, fwd)
|
||||
}
|
||||
}
|
||||
// NOTE: It is important to add new forwarder before removing old one.
|
||||
// Testing on Windows Server 2022 shows that removing forwarder1
|
||||
// then adding forwarder2 sometimes ends up adding both of them
|
||||
// to the forwarders list.
|
||||
cmd := fmt.Sprintf("Add-DnsServerForwarder -IPAddress %s", strings.Join(newForwarders, ","))
|
||||
if len(oldForwarders) > 0 {
|
||||
cmd = fmt.Sprintf("%s ; Remove-DnsServerForwarder -IPAddress %s -Force", cmd, strings.Join(oldForwarders, ","))
|
||||
}
|
||||
if out, err := powershell(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeDnsServerForwarders removes given nameservers from DNS server forwarders list.
|
||||
func removeDnsServerForwarders(nameservers []string) error {
|
||||
for _, ns := range nameservers {
|
||||
cmd := fmt.Sprintf("Remove-DnsServerForwarder -IPAddress %s -Force", ns)
|
||||
if out, err := powershell(cmd); err != nil {
|
||||
return fmt.Errorf("%w: %s", err, string(out))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// powershell runs the given powershell command.
|
||||
func powershell(cmd string) ([]byte, error) {
|
||||
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
|
||||
return bytes.TrimSpace(out), err
|
||||
}
|
||||
68
cmd/cli/os_windows_test.go
Normal file
68
cmd/cli/os_windows_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func Test_currentStaticDNS(t *testing.T) {
|
||||
iface, err := net.InterfaceByName(defaultIfaceName())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
start := time.Now()
|
||||
staticDns, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(staticDns)
|
||||
slices.Sort(staticDnsPowershell)
|
||||
if !slices.Equal(staticDns, staticDnsPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
|
||||
}
|
||||
}
|
||||
|
||||
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
1710
cmd/cli/prog.go
Normal file
1710
cmd/cli/prog.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,23 +1,11 @@
|
||||
package main
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func (p *prog) preRun() {
|
||||
if !service.Interactive() {
|
||||
p.setDNS()
|
||||
}
|
||||
}
|
||||
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
func (p *prog) preStop() {
|
||||
if !service.Interactive() {
|
||||
p.resetDNS()
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -6,17 +6,9 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func (p *prog) preRun() {
|
||||
if !service.Interactive() {
|
||||
p.setDNS()
|
||||
}
|
||||
}
|
||||
|
||||
func setDependencies(svc *service.Config) {
|
||||
// TODO(cuonglm): remove once https://github.com/kardianos/service/issues/359 fixed.
|
||||
_ = os.MkdirAll("/usr/local/etc/rc.d", 0755)
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {}
|
||||
|
||||
func (p *prog) preStop() {}
|
||||
65
cmd/cli/prog_linux.go
Normal file
65
cmd/cli/prog_linux.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if r, err := newLoopbackOSConfigurator(); err == nil {
|
||||
useSystemdResolved = r.Mode() == "systemd-resolved"
|
||||
}
|
||||
// Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911
|
||||
if os.Getenv("QUIC_GO_DISABLE_ECN") == "" {
|
||||
os.Setenv("QUIC_GO_DISABLE_ECN", "true")
|
||||
}
|
||||
}
|
||||
|
||||
func setDependencies(svc *service.Config) {
|
||||
svc.Dependencies = []string{
|
||||
"Wants=network-online.target",
|
||||
"After=network-online.target",
|
||||
"Wants=NetworkManager-wait-online.service",
|
||||
"After=NetworkManager-wait-online.service",
|
||||
"Wants=nss-lookup.target",
|
||||
"After=nss-lookup.target",
|
||||
}
|
||||
if out, _ := exec.Command("networkctl", "--no-pager").CombinedOutput(); len(out) > 0 {
|
||||
if wantsSystemDNetworkdWaitOnline(bytes.NewReader(out)) {
|
||||
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
|
||||
}
|
||||
}
|
||||
if routerDeps := router.ServiceDependencies(); len(routerDeps) > 0 {
|
||||
svc.Dependencies = append(svc.Dependencies, routerDeps...)
|
||||
}
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
// wantsSystemDNetworkdWaitOnline reports whether "systemd-networkd-wait-online" service
|
||||
// is required to be added to ctrld dependencies services.
|
||||
// The input reader r is the output of "networkctl --no-pager" command.
|
||||
func wantsSystemDNetworkdWaitOnline(r io.Reader) bool {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Skip header
|
||||
scanner.Scan()
|
||||
configured := false
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if len(fields) > 0 && fields[len(fields)-1] == "configured" {
|
||||
configured = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return configured
|
||||
}
|
||||
48
cmd/cli/prog_linux_test.go
Normal file
48
cmd/cli/prog_linux_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
networkctlUnmanagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||
1 lo loopback carrier unmanaged
|
||||
2 wlp0s20f3 wlan routable unmanaged
|
||||
3 tailscale0 none routable unmanaged
|
||||
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||
5 docker0 bridge no-carrier unmanaged
|
||||
|
||||
5 links listed.
|
||||
`
|
||||
networkctlManagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||
1 lo loopback carrier unmanaged
|
||||
2 wlp0s20f3 wlan routable configured
|
||||
3 tailscale0 none routable unmanaged
|
||||
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||
5 docker0 bridge no-carrier unmanaged
|
||||
|
||||
5 links listed.
|
||||
`
|
||||
)
|
||||
|
||||
func Test_wantsSystemDNetworkdWaitOnline(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
r io.Reader
|
||||
required bool
|
||||
}{
|
||||
{"unmanaged", strings.NewReader(networkctlUnmanagedOutput), false},
|
||||
{"managed", strings.NewReader(networkctlManagedOutput), true},
|
||||
{"empty", strings.NewReader(""), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if required := wantsSystemDNetworkdWaitOnline(tc.r); required != tc.required {
|
||||
t.Errorf("wants %v got %v", tc.required, required)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,12 @@
|
||||
//go:build !linux && !freebsd && !darwin
|
||||
//go:build !linux && !freebsd && !darwin && !windows
|
||||
|
||||
package main
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
func (p *prog) preRun() {}
|
||||
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
func (p *prog) preStop() {}
|
||||
273
cmd/cli/prog_test.go
Normal file
273
cmd/cli/prog_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is true.
|
||||
assert.True(t, p.dnsWatchdogEnabled())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
}{
|
||||
{"enabled", true},
|
||||
{"disabled", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogEnabled = &tc.enabled
|
||||
assert.Equal(t, tc.enabled, p.dnsWatchdogEnabled())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_dnsWatchdogInterval(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is 20s.
|
||||
assert.Equal(t, dnsWatchdogDefaultInterval, p.dnsWatchdogDuration())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
expected time.Duration
|
||||
}{
|
||||
{"valid", time.Minute, time.Minute},
|
||||
{"zero", 0, dnsWatchdogDefaultInterval},
|
||||
{"nagative", time.Duration(-1 * time.Minute), dnsWatchdogDefaultInterval},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogInvterval = &tc.duration
|
||||
assert.Equal(t, tc.expected, p.dnsWatchdogDuration())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_shouldUpgrade(t *testing.T) {
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is empty",
|
||||
},
|
||||
{
|
||||
name: "invalid version target",
|
||||
versionTarget: "invalid-version",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is invalid",
|
||||
},
|
||||
{
|
||||
name: "same version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version equals current version",
|
||||
},
|
||||
{
|
||||
name: "older version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.1.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version is older than current version",
|
||||
},
|
||||
{
|
||||
name: "patch upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow patch version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "minor upgrade allowed",
|
||||
versionTarget: "v1.1.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow minor version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "major upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "major downgrade blocked",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v2.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version downgrade",
|
||||
},
|
||||
{
|
||||
name: "version without v prefix",
|
||||
versionTarget: "1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should handle version target without v prefix",
|
||||
},
|
||||
{
|
||||
name: "complex version upgrade allowed",
|
||||
versionTarget: "v1.5.3",
|
||||
currentVersion: makeVersion("v1.4.2"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow complex version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "complex major upgrade blocked",
|
||||
versionTarget: "v3.1.0",
|
||||
currentVersion: makeVersion("v2.5.3"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block complex major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "pre-release version upgrade allowed",
|
||||
versionTarget: "v1.0.1-beta.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow pre-release version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "pre-release major upgrade blocked",
|
||||
versionTarget: "v2.0.0-alpha.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block pre-release major version upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
|
||||
|
||||
// Call the function and capture the result
|
||||
result := shouldUpgrade(tc.versionTarget, tc.currentVersion, &testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_selfUpgradeCheck(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipped due to Windows file locking issue on Github Action runners")
|
||||
}
|
||||
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow upgrade and attempt to perform it",
|
||||
},
|
||||
{
|
||||
name: "upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block upgrade and not attempt to perform it",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := zerolog.New(zerolog.NewTestWriter(t)).With().Logger()
|
||||
|
||||
// Call the function and capture the result
|
||||
result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, &testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_performUpgrade(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipped due to Windows file locking issue on Github Action runners")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid version target",
|
||||
versionTarget: "v1.0.1",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade with valid version target",
|
||||
},
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade even with empty version target",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Call the function and capture the result
|
||||
result := performUpgrade(tc.versionTarget)
|
||||
assert.Equal(t, tc.expectedResult, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
14
cmd/cli/prog_windows.go
Normal file
14
cmd/cli/prog_windows.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
func setDependencies(svc *service.Config) {
|
||||
if hasLocalDnsServerRunning() {
|
||||
svc.Dependencies = []string{"DNS"}
|
||||
}
|
||||
}
|
||||
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
57
cmd/cli/prometheus.go
Normal file
57
cmd/cli/prometheus.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package cli
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
const (
|
||||
metricsLabelListener = "listener"
|
||||
metricsLabelClientSourceIP = "client_source_ip"
|
||||
metricsLabelClientMac = "client_mac"
|
||||
metricsLabelClientHostname = "client_hostname"
|
||||
metricsLabelUpstream = "upstream"
|
||||
metricsLabelRRType = "rr_type"
|
||||
metricsLabelRCode = "rcode"
|
||||
)
|
||||
|
||||
// statsVersion represent ctrld version.
|
||||
var statsVersion = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_build_info",
|
||||
Help: "Version of ctrld process.",
|
||||
}, []string{"gitref", "goversion", "version"})
|
||||
|
||||
// statsTimeStart represents start time of ctrld service.
|
||||
var statsTimeStart = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ctrld_time_seconds",
|
||||
Help: "Start time of the ctrld process since unix epoch in seconds.",
|
||||
})
|
||||
|
||||
var statsQueriesCountLabels = []string{
|
||||
metricsLabelListener,
|
||||
metricsLabelClientSourceIP,
|
||||
metricsLabelClientMac,
|
||||
metricsLabelClientHostname,
|
||||
metricsLabelUpstream,
|
||||
metricsLabelRRType,
|
||||
metricsLabelRCode,
|
||||
}
|
||||
|
||||
// statsQueriesCount counts total number of queries.
|
||||
var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_queries_count",
|
||||
Help: "Total number of queries.",
|
||||
}, statsQueriesCountLabels)
|
||||
|
||||
// statsClientQueriesCount counts total number of queries of a client.
|
||||
//
|
||||
// The labels "client_source_ip", "client_mac", "client_hostname" are unbounded,
|
||||
// thus this stat is highly inefficient if there are many devices.
|
||||
var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_client_queries_count",
|
||||
Help: "Total number queries of a client.",
|
||||
}, []string{metricsLabelClientSourceIP, metricsLabelClientMac, metricsLabelClientHostname})
|
||||
|
||||
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
|
||||
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
|
||||
if p.metricsQueryStats.Load() {
|
||||
c.WithLabelValues(lvs...).Inc()
|
||||
}
|
||||
}
|
||||
17
cmd/cli/reload_others.go
Normal file
17
cmd/cli/reload_others.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {
|
||||
signal.Notify(ch, syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
|
||||
}
|
||||
18
cmd/cli/reload_windows.go
Normal file
18
cmd/cli/reload_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func notifyReloadSigCh(ch chan os.Signal) {}
|
||||
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
select {
|
||||
case p.reloadCh <- struct{}{}:
|
||||
return nil
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
return errors.New("timeout while sending reload signal")
|
||||
}
|
||||
164
cmd/cli/resolvconf.go
Normal file
164
cmd/cli/resolvconf.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
|
||||
// Returns nil if no nameservers are found.
|
||||
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the file for "nameserver" lines
|
||||
var currentNS []string
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "nameserver") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) >= 2 {
|
||||
currentNS = append(currentNS, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentNS, nil
|
||||
}
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
resolvConfPath := "/etc/resolv.conf"
|
||||
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
|
||||
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
|
||||
resolvConfPath = rp
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("start watching %s file", resolvConfPath)
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("could not create watcher for /etc/resolv.conf")
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
// We watch /etc instead of /etc/resolv.conf directly,
|
||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||
watchDir := filepath.Dir(resolvConfPath)
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("could not add %s to watcher list", watchDir)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
mainLog.Load().Debug().Msgf("stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if p.recoveryRunning.Load() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Name != resolvConfPath { // skip if not /etc/resolv.conf changes.
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
mainLog.Load().Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
expectedNS := make([]string, len(ns))
|
||||
for i, addr := range ns {
|
||||
expectedNS[i] = addr.String()
|
||||
}
|
||||
|
||||
var foundNS []string
|
||||
var err error
|
||||
|
||||
maxRetries := 1
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
// If we found nameservers, break out of retry loop
|
||||
if len(foundNS) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
if retry < maxRetries-1 {
|
||||
mainLog.Load().Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
|
||||
// If we found nameservers, check if they match what we expect
|
||||
if len(foundNS) > 0 {
|
||||
// Check if the nameservers match exactly what we expect
|
||||
matches := len(foundNS) == len(expectedNS)
|
||||
if matches {
|
||||
for i := range foundNS {
|
||||
if foundNS[i] != expectedNS[i] {
|
||||
matches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
Msg("checking nameservers")
|
||||
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Err(err).Msg("could not get event for /etc/resolv.conf")
|
||||
}
|
||||
}
|
||||
}
|
||||
49
cmd/cli/resolvconf_darwin.go
Normal file
49
cmd/cli/resolvconf_darwin.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
servers := make([]string, len(ns))
|
||||
for i := range ns {
|
||||
servers[i] = ns[i].String()
|
||||
}
|
||||
if err := setDNS(iface, servers); err != nil {
|
||||
return err
|
||||
}
|
||||
slices.Sort(servers)
|
||||
curNs := currentDNS(iface)
|
||||
slices.Sort(curNs)
|
||||
if !slices.Equal(curNs, servers) {
|
||||
c, err := resolvconffile.ParseFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Nameservers = ns
|
||||
f, err := os.Create(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := c.Write(f); err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
return true
|
||||
}
|
||||
52
cmd/cli/resolvconf_not_darwin_unix.go
Normal file
52
cmd/cli/resolvconf_not_darwin_unix.go
Normal file
@@ -0,0 +1,52 @@
|
||||
//go:build unix && !darwin
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oc := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
oc.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get search domains list when reverting resolv.conf file")
|
||||
}
|
||||
return r.SetDNS(oc)
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
switch r.Mode() {
|
||||
case "direct", "resolvconf":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface.
|
||||
func newLoopbackOSConfigurator() (dns.OSConfigurator, error) {
|
||||
return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo")
|
||||
}
|
||||
16
cmd/cli/resolvconf_windows.go
Normal file
16
cmd/cli/resolvconf_windows.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func setResolvConf(_ *net.Interface, _ []netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
return false
|
||||
}
|
||||
14
cmd/cli/search_domains_unix.go
Normal file
14
cmd/cli/search_domains_unix.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
return resolvconffile.SearchDomains()
|
||||
}
|
||||
43
cmd/cli/search_domains_windows.go
Normal file
43
cmd/cli/search_domains_windows.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
flags := winipcfg.GAAFlagIncludeGateways |
|
||||
winipcfg.GAAFlagIncludePrefix
|
||||
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err)
|
||||
}
|
||||
|
||||
var sds []dnsname.FQDN
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
continue
|
||||
}
|
||||
|
||||
for a := aa.FirstDNSSuffix; a != nil; a = a.Next {
|
||||
d, err := dnsname.ToFQDN(a.String())
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("failed to parse domain: %s", a.String())
|
||||
continue
|
||||
}
|
||||
sds = append(sds, d)
|
||||
}
|
||||
}
|
||||
return sds, nil
|
||||
}
|
||||
7
cmd/cli/self_delete_others.go
Normal file
7
cmd/cli/self_delete_others.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
var supportedSelfDelete = true
|
||||
|
||||
func selfDeleteExe() error { return nil }
|
||||
134
cmd/cli/self_delete_windows.go
Normal file
134
cmd/cli/self_delete_windows.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copied from https://github.com/secur30nly/go-self-delete
|
||||
// with modification to suitable for ctrld usage.
|
||||
|
||||
/*
|
||||
License: MIT Licence
|
||||
|
||||
References:
|
||||
- https://github.com/LloydLabs/delete-self-poc
|
||||
- https://twitter.com/jonasLyk/status/1350401461985955840
|
||||
*/
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var supportedSelfDelete = false
|
||||
|
||||
type FILE_RENAME_INFO struct {
|
||||
Union struct {
|
||||
ReplaceIfExists bool
|
||||
Flags uint32
|
||||
}
|
||||
RootDirectory windows.Handle
|
||||
FileNameLength uint32
|
||||
FileName [1]uint16
|
||||
}
|
||||
|
||||
type FILE_DISPOSITION_INFO struct {
|
||||
DeleteFile bool
|
||||
}
|
||||
|
||||
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
|
||||
handle, err := windows.CreateFile(
|
||||
pwPath,
|
||||
windows.DELETE,
|
||||
0,
|
||||
nil,
|
||||
windows.OPEN_EXISTING,
|
||||
windows.FILE_ATTRIBUTE_NORMAL,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func dsRenameHandle(hHandle windows.Handle) error {
|
||||
var fRename FILE_RENAME_INFO
|
||||
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lpwStream := &DS_STREAM_RENAME[0]
|
||||
fRename.FileNameLength = uint32(unsafe.Sizeof(lpwStream))
|
||||
|
||||
windows.NewLazyDLL("kernel32.dll").NewProc("RtlCopyMemory").Call(
|
||||
uintptr(unsafe.Pointer(&fRename.FileName[0])),
|
||||
uintptr(unsafe.Pointer(lpwStream)),
|
||||
unsafe.Sizeof(lpwStream),
|
||||
)
|
||||
|
||||
err = windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileRenameInfo,
|
||||
(*byte)(unsafe.Pointer(&fRename)),
|
||||
uint32(unsafe.Sizeof(fRename)+unsafe.Sizeof(lpwStream)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func dsDepositeHandle(hHandle windows.Handle) error {
|
||||
var fDelete FILE_DISPOSITION_INFO
|
||||
fDelete.DeleteFile = true
|
||||
|
||||
err := windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileDispositionInfo,
|
||||
(*byte)(unsafe.Pointer(&fDelete)),
|
||||
uint32(unsafe.Sizeof(fDelete)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func selfDeleteExe() error {
|
||||
var wcPath [windows.MAX_PATH + 1]uint16
|
||||
var hCurrent windows.Handle
|
||||
|
||||
_, err := windows.GetModuleFileName(0, &wcPath[0], windows.MAX_PATH)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsRenameHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsDepositeHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
|
||||
return windows.CloseHandle(hCurrent)
|
||||
}
|
||||
16
cmd/cli/self_kill_others.go
Normal file
16
cmd/cli/self_kill_others.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, false) {
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
45
cmd/cli/self_kill_unix.go
Normal file
45
cmd/cli/self_kill_unix.go
Normal file
@@ -0,0 +1,45 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func selfUninstall(p *prog, logger zerolog.Logger) {
|
||||
if runtime.GOOS == "linux" {
|
||||
selfUninstallLinux(p, logger)
|
||||
}
|
||||
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("could not determine executable")
|
||||
}
|
||||
args := []string{"uninstall"}
|
||||
if deactivationPinSet() {
|
||||
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin.Load()))
|
||||
}
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
if err := cmd.Start(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("could not start self uninstall command")
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
_ = cmd.Wait()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func selfUninstallLinux(p *prog, logger zerolog.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, true) {
|
||||
logger.Warn().Msgf("service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
12
cmd/cli/self_upgrade_others.go
Normal file
12
cmd/cli/self_upgrade_others.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running a detached child command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{Setsid: true}
|
||||
}
|
||||
18
cmd/cli/self_upgrade_windows.go
Normal file
18
cmd/cli/self_upgrade_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags?redirectedfrom=MSDN
|
||||
|
||||
// SYSCALL_CREATE_NO_WINDOW set flag to run process without a console window.
|
||||
const SYSCALL_CREATE_NO_WINDOW = 0x08000000
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running self-upgrade command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | SYSCALL_CREATE_NO_WINDOW,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
24
cmd/cli/sema.go
Normal file
24
cmd/cli/sema.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package cli
|
||||
|
||||
type semaphore interface {
|
||||
acquire()
|
||||
release()
|
||||
}
|
||||
|
||||
type noopSemaphore struct{}
|
||||
|
||||
func (n noopSemaphore) acquire() {}
|
||||
|
||||
func (n noopSemaphore) release() {}
|
||||
|
||||
type chanSemaphore struct {
|
||||
ready chan struct{}
|
||||
}
|
||||
|
||||
func (c *chanSemaphore) acquire() {
|
||||
c.ready <- struct{}{}
|
||||
}
|
||||
|
||||
func (c *chanSemaphore) release() {
|
||||
<-c.ready
|
||||
}
|
||||
268
cmd/cli/service.go
Normal file
268
cmd/cli/service.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/unit"
|
||||
"github.com/kardianos/service"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router/openwrt"
|
||||
)
|
||||
|
||||
// newService wraps service.New call to return service.Service
|
||||
// wrapper which is suitable for the current platform.
|
||||
func newService(i service.Interface, c *service.Config) (service.Service, error) {
|
||||
s, err := service.New(i, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case router.IsOldOpenwrt(), router.IsNetGearOrbi():
|
||||
return &procd{sysV: &sysV{s}, svcConfig: c}, nil
|
||||
case router.IsGLiNet():
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "unix-systemv":
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "linux-systemd":
|
||||
return &systemd{s}, nil
|
||||
case s.Platform() == "darwin-launchd":
|
||||
return newLaunchd(s), nil
|
||||
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// sysV wraps a service.Service, and provide start/stop/status command
|
||||
// base on "/etc/init.d/<service_name>".
|
||||
//
|
||||
// Use this on system where "service" command is not available, like GL.iNET router.
|
||||
type sysV struct {
|
||||
service.Service
|
||||
}
|
||||
|
||||
func (s *sysV) installed() bool {
|
||||
fi, err := os.Stat("/etc/init.d/ctrld")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
mode := fi.Mode()
|
||||
return mode.IsRegular() && (mode&0111) != 0
|
||||
}
|
||||
|
||||
func (s *sysV) Start() error {
|
||||
if !s.installed() {
|
||||
return service.ErrNotInstalled
|
||||
}
|
||||
_, err := exec.Command("/etc/init.d/ctrld", "start").CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *sysV) Stop() error {
|
||||
if !s.installed() {
|
||||
return service.ErrNotInstalled
|
||||
}
|
||||
_, err := exec.Command("/etc/init.d/ctrld", "stop").CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *sysV) Restart() error {
|
||||
if !s.installed() {
|
||||
return service.ErrNotInstalled
|
||||
}
|
||||
// We don't care about error returned by s.Stop,
|
||||
// because the service may already be stopped.
|
||||
_ = s.Stop()
|
||||
return s.Start()
|
||||
}
|
||||
|
||||
func (s *sysV) Status() (service.Status, error) {
|
||||
if !s.installed() {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
return unixSystemVServiceStatus()
|
||||
}
|
||||
|
||||
// procd wraps a service.Service, and provide start/stop command
|
||||
// base on "/etc/init.d/<service_name>", status command base on parsing "ps" command output.
|
||||
//
|
||||
// Use this on system where "/etc/init.d/<service_name> status" command is not available,
|
||||
// like old GL.iNET Opal router.
|
||||
type procd struct {
|
||||
*sysV
|
||||
svcConfig *service.Config
|
||||
}
|
||||
|
||||
func (s *procd) Status() (service.Status, error) {
|
||||
if !s.installed() {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
bin := s.svcConfig.Executable
|
||||
if bin == "" {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
bin = exe
|
||||
}
|
||||
|
||||
// Looking for something like "/sbin/ctrld run ".
|
||||
shellCmd := fmt.Sprintf("ps | grep -q %q", bin+" [r]un ")
|
||||
if err := exec.Command("sh", "-c", shellCmd).Run(); err != nil {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
return service.StatusRunning, nil
|
||||
}
|
||||
|
||||
// systemd wraps a service.Service, and provide status command to
|
||||
// report the status correctly.
|
||||
type systemd struct {
|
||||
service.Service
|
||||
}
|
||||
|
||||
func (s *systemd) Status() (service.Status, error) {
|
||||
out, _ := exec.Command("systemctl", "status", "ctrld").CombinedOutput()
|
||||
if bytes.Contains(out, []byte("/FAILURE)")) {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
return s.Service.Status()
|
||||
}
|
||||
|
||||
func (s *systemd) Start() error {
|
||||
const systemdUnitFile = "/etc/systemd/system/ctrld.service"
|
||||
f, err := os.Open(systemdUnitFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if opts, change := ensureSystemdKillMode(f); change {
|
||||
mode := os.FileMode(0644)
|
||||
buf, err := io.ReadAll(unit.Serialize(opts))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(systemdUnitFile, buf, mode); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("systemctl daemon-reload failed: %w\n%s", err, string(out))
|
||||
}
|
||||
mainLog.Load().Debug().Msg("set KillMode=process successfully")
|
||||
}
|
||||
return s.Service.Start()
|
||||
}
|
||||
|
||||
// ensureSystemdKillMode ensure systemd unit file is configured with KillMode=process.
|
||||
// This is necessary for running self-upgrade flow.
|
||||
func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) {
|
||||
opts, err := unit.DeserializeOptions(r)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("failed to deserialize options")
|
||||
return
|
||||
}
|
||||
change = true
|
||||
needKillModeOpt := true
|
||||
killModeOpt := unit.NewUnitOption("Service", "KillMode", "process")
|
||||
for _, opt := range opts {
|
||||
if opt.Match(killModeOpt) {
|
||||
needKillModeOpt = false
|
||||
change = false
|
||||
break
|
||||
}
|
||||
if opt.Section == killModeOpt.Section && opt.Name == killModeOpt.Name {
|
||||
opt.Value = killModeOpt.Value
|
||||
needKillModeOpt = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needKillModeOpt {
|
||||
opts = append(opts, killModeOpt)
|
||||
}
|
||||
return opts, change
|
||||
}
|
||||
|
||||
func newLaunchd(s service.Service) *launchd {
|
||||
return &launchd{
|
||||
Service: s,
|
||||
statusErrMsg: "Permission denied",
|
||||
}
|
||||
}
|
||||
|
||||
// launchd wraps a service.Service, and provide status command to
|
||||
// report the status correctly when not running as root on Darwin.
|
||||
//
|
||||
// TODO: remove this wrapper once https://github.com/kardianos/service/issues/400 fixed.
|
||||
type launchd struct {
|
||||
service.Service
|
||||
statusErrMsg string
|
||||
}
|
||||
|
||||
func (l *launchd) Status() (service.Status, error) {
|
||||
if os.Geteuid() != 0 {
|
||||
return service.StatusUnknown, errors.New(l.statusErrMsg)
|
||||
}
|
||||
return l.Service.Status()
|
||||
}
|
||||
|
||||
type task struct {
|
||||
f func() error
|
||||
abortOnError bool
|
||||
Name string
|
||||
}
|
||||
|
||||
func doTasks(tasks []task) bool {
|
||||
for _, task := range tasks {
|
||||
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
|
||||
if err := task.f(); err != nil {
|
||||
if task.abortOnError {
|
||||
mainLog.Load().Error().Msgf("error running task %s: %v", task.Name, err)
|
||||
return false
|
||||
}
|
||||
// if this is darwin stop command, dont print debug
|
||||
// since launchctl complains on every start
|
||||
if runtime.GOOS != "darwin" || task.Name != "Stop" {
|
||||
mainLog.Load().Debug().Msgf("error running task %s: %v", task.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func checkHasElevatedPrivilege() {
|
||||
ok, err := hasElevatedPrivilege()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("could not detect user privilege: %v", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
mainLog.Load().Error().Msg("Please relaunch process with admin/root privilege.")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func unixSystemVServiceStatus() (service.Status, error) {
|
||||
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
|
||||
if err != nil {
|
||||
// Specific case for openwrt >= 24.10, it returns non-success code
|
||||
// for above status command, which may not right.
|
||||
if router.Name() == openwrt.Name {
|
||||
if string(bytes.ToLower(bytes.TrimSpace(out))) == "inactive" {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
|
||||
switch string(bytes.ToLower(bytes.TrimSpace(out))) {
|
||||
case "running":
|
||||
return service.StatusRunning, nil
|
||||
default:
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
134
cmd/cli/service_args_darwin.go
Normal file
134
cmd/cli/service_args_darwin.go
Normal file
@@ -0,0 +1,134 @@
|
||||
//go:build darwin
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const launchdPlistPath = "/Library/LaunchDaemons/ctrld.plist"
|
||||
|
||||
// serviceConfigFileExists returns true if the launchd plist for ctrld exists on disk.
|
||||
// This is more reliable than checking launchctl status, which may report "not found"
|
||||
// if the service was unloaded but the plist file still exists.
|
||||
func serviceConfigFileExists() bool {
|
||||
_, err := os.Stat(launchdPlistPath)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// appendServiceFlag appends a CLI flag (e.g., "--intercept-mode") to the installed
|
||||
// service's launch arguments. This is used when upgrading an existing installation
|
||||
// to intercept mode without losing the existing --cd flag and other arguments.
|
||||
//
|
||||
// On macOS, this modifies the launchd plist at /Library/LaunchDaemons/ctrld.plist
|
||||
// using the "defaults" command, which is the standard way to edit plists.
|
||||
//
|
||||
// The function is idempotent: if the flag already exists, it's a no-op.
|
||||
func appendServiceFlag(flag string) error {
|
||||
// Read current ProgramArguments from plist.
|
||||
out, err := exec.Command("defaults", "read", launchdPlistPath, "ProgramArguments").CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read plist ProgramArguments: %w (output: %s)", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
|
||||
// Check if the flag is already present (idempotent).
|
||||
args := string(out)
|
||||
if strings.Contains(args, flag) {
|
||||
mainLog.Load().Debug().Msgf("Service flag %q already present in plist, skipping", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use PlistBuddy to append the flag to ProgramArguments array.
|
||||
// PlistBuddy is more reliable than "defaults" for array manipulation.
|
||||
addCmd := exec.Command(
|
||||
"/usr/libexec/PlistBuddy",
|
||||
"-c", fmt.Sprintf("Add :ProgramArguments: string %s", flag),
|
||||
launchdPlistPath,
|
||||
)
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to append %q to plist ProgramArguments: %w (output: %s)", flag, err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
|
||||
mainLog.Load().Info().Msgf("Appended %q to service launch arguments", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyServiceRegistration is a no-op on macOS (launchd plist verification not needed).
|
||||
func verifyServiceRegistration() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeServiceFlag removes a CLI flag (and its value, if the next argument is not
|
||||
// a flag) from the installed service's launch arguments. For example, removing
|
||||
// "--intercept-mode" also removes the following "dns" or "hard" value argument.
|
||||
//
|
||||
// The function is idempotent: if the flag doesn't exist, it's a no-op.
|
||||
func removeServiceFlag(flag string) error {
|
||||
// Read current ProgramArguments to find the index.
|
||||
out, err := exec.Command("/usr/libexec/PlistBuddy", "-c", "Print :ProgramArguments", launchdPlistPath).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read plist ProgramArguments: %w (output: %s)", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
|
||||
// Parse the PlistBuddy output to find the flag's index.
|
||||
// PlistBuddy prints arrays as:
|
||||
// Array {
|
||||
// /path/to/ctrld
|
||||
// run
|
||||
// --cd=xxx
|
||||
// --intercept-mode
|
||||
// dns
|
||||
// }
|
||||
lines := strings.Split(string(out), "\n")
|
||||
var entries []string
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "Array {" || trimmed == "}" || trimmed == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, trimmed)
|
||||
}
|
||||
|
||||
index := -1
|
||||
for i, entry := range entries {
|
||||
if entry == flag {
|
||||
index = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if index < 0 {
|
||||
mainLog.Load().Debug().Msgf("Service flag %q not present in plist, skipping removal", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the next entry is a value (not a flag). If so, delete it first
|
||||
// (deleting by index shifts subsequent entries down, so delete value before flag).
|
||||
hasValue := index+1 < len(entries) && !strings.HasPrefix(entries[index+1], "-")
|
||||
if hasValue {
|
||||
delVal := exec.Command(
|
||||
"/usr/libexec/PlistBuddy",
|
||||
"-c", fmt.Sprintf("Delete :ProgramArguments:%d", index+1),
|
||||
launchdPlistPath,
|
||||
)
|
||||
if out, err := delVal.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to remove value for %q from plist: %w (output: %s)", flag, err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the flag itself.
|
||||
delCmd := exec.Command(
|
||||
"/usr/libexec/PlistBuddy",
|
||||
"-c", fmt.Sprintf("Delete :ProgramArguments:%d", index),
|
||||
launchdPlistPath,
|
||||
)
|
||||
if out, err := delCmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("failed to remove %q from plist ProgramArguments: %w (output: %s)", flag, err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
|
||||
mainLog.Load().Info().Msgf("Removed %q from service launch arguments", flag)
|
||||
return nil
|
||||
}
|
||||
38
cmd/cli/service_args_others.go
Normal file
38
cmd/cli/service_args_others.go
Normal file
@@ -0,0 +1,38 @@
|
||||
//go:build !darwin && !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// serviceConfigFileExists checks common service config file locations on Linux.
|
||||
func serviceConfigFileExists() bool {
|
||||
// systemd unit file
|
||||
if _, err := os.Stat("/etc/systemd/system/ctrld.service"); err == nil {
|
||||
return true
|
||||
}
|
||||
// SysV init script
|
||||
if _, err := os.Stat("/etc/init.d/ctrld"); err == nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// appendServiceFlag is not yet implemented on this platform.
|
||||
// Linux services (systemd) store args in unit files; intercept mode
|
||||
// should be set via the config file (intercept_mode) on these platforms.
|
||||
func appendServiceFlag(flag string) error {
|
||||
return fmt.Errorf("appending service flags is not supported on this platform; use intercept_mode in config instead")
|
||||
}
|
||||
|
||||
// verifyServiceRegistration is a no-op on this platform.
|
||||
func verifyServiceRegistration() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeServiceFlag is not yet implemented on this platform.
|
||||
func removeServiceFlag(flag string) error {
|
||||
return fmt.Errorf("removing service flags is not supported on this platform; use intercept_mode in config instead")
|
||||
}
|
||||
153
cmd/cli/service_args_windows.go
Normal file
153
cmd/cli/service_args_windows.go
Normal file
@@ -0,0 +1,153 @@
|
||||
//go:build windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
// serviceConfigFileExists returns true if the ctrld Windows service is registered.
|
||||
func serviceConfigFileExists() bool {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer m.Disconnect()
|
||||
s, err := m.OpenService(ctrldServiceName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
s.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
// appendServiceFlag appends a CLI flag (e.g., "--intercept-mode") to the installed
|
||||
// Windows service's BinPath arguments. This is used when upgrading an existing
|
||||
// installation to intercept mode without losing the existing --cd flag.
|
||||
//
|
||||
// The function is idempotent: if the flag already exists, it's a no-op.
|
||||
func appendServiceFlag(flag string) error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Windows SCM: %w", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(ctrldServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open service %q: %w", ctrldServiceName, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
config, err := s.Config()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read service config: %w", err)
|
||||
}
|
||||
|
||||
// Check if flag already present (idempotent).
|
||||
if strings.Contains(config.BinaryPathName, flag) {
|
||||
mainLog.Load().Debug().Msgf("Service flag %q already present in BinPath, skipping", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Append the flag to BinPath.
|
||||
config.BinaryPathName = strings.TrimSpace(config.BinaryPathName) + " " + flag
|
||||
|
||||
if err := s.UpdateConfig(config); err != nil {
|
||||
return fmt.Errorf("failed to update service config with %q: %w", flag, err)
|
||||
}
|
||||
|
||||
mainLog.Load().Info().Msgf("Appended %q to service BinPath", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyServiceRegistration opens the Windows Service Control Manager and verifies
|
||||
// that the ctrld service is correctly registered: logs the BinaryPathName, checks
|
||||
// that --intercept-mode is present if expected, and verifies SERVICE_AUTO_START.
|
||||
func verifyServiceRegistration() error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Windows SCM: %w", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(ctrldServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open service %q: %w", ctrldServiceName, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
config, err := s.Config()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read service config: %w", err)
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msgf("Service registry: BinaryPathName = %q", config.BinaryPathName)
|
||||
|
||||
// If intercept mode is set, verify the flag is present in BinPath.
|
||||
if interceptMode == "dns" || interceptMode == "hard" {
|
||||
if !strings.Contains(config.BinaryPathName, "--intercept-mode") {
|
||||
return fmt.Errorf("service registry: --intercept-mode flag missing from BinaryPathName (expected mode %q)", interceptMode)
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Service registry: --intercept-mode flag present in BinaryPathName")
|
||||
}
|
||||
|
||||
// Verify auto-start. mgr.StartAutomatic == 2 == SERVICE_AUTO_START.
|
||||
if config.StartType != mgr.StartAutomatic {
|
||||
return fmt.Errorf("service registry: StartType is %d, expected SERVICE_AUTO_START (%d)", config.StartType, mgr.StartAutomatic)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeServiceFlag removes a CLI flag (and its value, if present) from the installed
|
||||
// Windows service's BinPath. For example, removing "--intercept-mode" also removes
|
||||
// the following "dns" or "hard" value. The function is idempotent.
|
||||
func removeServiceFlag(flag string) error {
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Windows SCM: %w", err)
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(ctrldServiceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open service %q: %w", ctrldServiceName, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
config, err := s.Config()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read service config: %w", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(config.BinaryPathName, flag) {
|
||||
mainLog.Load().Debug().Msgf("Service flag %q not present in BinPath, skipping removal", flag)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Split BinPath into parts, find and remove the flag + its value (if any).
|
||||
parts := strings.Fields(config.BinaryPathName)
|
||||
var newParts []string
|
||||
for i := 0; i < len(parts); i++ {
|
||||
if parts[i] == flag {
|
||||
// Skip the flag. Also skip the next part if it's a value (not a flag).
|
||||
if i+1 < len(parts) && !strings.HasPrefix(parts[i+1], "-") {
|
||||
i++ // skip value too
|
||||
}
|
||||
continue
|
||||
}
|
||||
newParts = append(newParts, parts[i])
|
||||
}
|
||||
config.BinaryPathName = strings.Join(newParts, " ")
|
||||
|
||||
if err := s.UpdateConfig(config); err != nil {
|
||||
return fmt.Errorf("failed to update service config: %w", err)
|
||||
}
|
||||
|
||||
mainLog.Load().Info().Msgf("Removed %q from service BinPath", flag)
|
||||
return nil
|
||||
}
|
||||
22
cmd/cli/service_others.go
Normal file
22
cmd/cli/service_others.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
return os.Geteuid() == 0, nil
|
||||
}
|
||||
|
||||
func openLogFile(path string, flags int) (*os.File, error) {
|
||||
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||
}
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool { return false }
|
||||
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) { return false, 0 }
|
||||
28
cmd/cli/service_test.go
Normal file
28
cmd/cli/service_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_ensureSystemdKillMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
unitFile string
|
||||
wantChange bool
|
||||
}{
|
||||
{"no KillMode", "[Service]\nExecStart=/bin/sleep 1", true},
|
||||
{"not KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=mixed", true},
|
||||
{"KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
{"invalid unit file", "[Service\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, change := ensureSystemdKillMode(strings.NewReader(tc.unitFile)); tc.wantChange != change {
|
||||
t.Errorf("ensureSystemdKillMode(%q) = %v, want %v", tc.unitFile, change, tc.wantChange)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
228
cmd/cli/service_windows.go
Normal file
228
cmd/cli/service_windows.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
"github.com/microsoft/wmi/pkg/base/instance"
|
||||
"github.com/microsoft/wmi/pkg/base/query"
|
||||
"github.com/microsoft/wmi/pkg/constant"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
var sid *windows.SID
|
||||
if err := windows.AllocateAndInitializeSid(
|
||||
&windows.SECURITY_NT_AUTHORITY,
|
||||
2,
|
||||
windows.SECURITY_BUILTIN_DOMAIN_RID,
|
||||
windows.DOMAIN_ALIAS_RID_ADMINS,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
&sid,
|
||||
); err != nil {
|
||||
return false, err
|
||||
}
|
||||
token := windows.Token(0)
|
||||
return token.IsMember(sid)
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions checks if the given service
|
||||
// has the correct failure actions configured, and updates them if not.
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil // no-op on non-Windows
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// 1. Retrieve the current config
|
||||
cfg, err := s.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Update the Description
|
||||
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
|
||||
|
||||
// 3. Apply the updated config
|
||||
if err := s.UpdateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then proceed with existing actions, e.g. setting failure actions
|
||||
actions := []mgr.RecoveryAction{
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
}
|
||||
|
||||
// Set the recovery actions (3 restarts, reset period = 120).
|
||||
err = s.SetRecoveryActions(actions, 120)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that failure actions are NOT triggered on user-initiated stops.
|
||||
var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG
|
||||
failureActionsFlag.FailureActionsOnNonCrashFailures = 0
|
||||
|
||||
if err := windows.ChangeServiceConfig2(
|
||||
s.Handle,
|
||||
windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
|
||||
(*byte)(unsafe.Pointer(&failureActionsFlag)),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openLogFile(path string, mode int) (*os.File, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||
}
|
||||
|
||||
pathP, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var access uint32
|
||||
switch mode & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) {
|
||||
case os.O_RDONLY:
|
||||
access = windows.GENERIC_READ
|
||||
case os.O_WRONLY:
|
||||
access = windows.GENERIC_WRITE
|
||||
case os.O_RDWR:
|
||||
access = windows.GENERIC_READ | windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_CREATE != 0 {
|
||||
access |= windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_APPEND != 0 {
|
||||
access &^= windows.GENERIC_WRITE
|
||||
access |= windows.FILE_APPEND_DATA
|
||||
}
|
||||
|
||||
shareMode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
|
||||
|
||||
var sa *syscall.SecurityAttributes
|
||||
|
||||
var createMode uint32
|
||||
switch {
|
||||
case mode&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL):
|
||||
createMode = windows.CREATE_NEW
|
||||
case mode&(os.O_CREATE|os.O_TRUNC) == (os.O_CREATE | os.O_TRUNC):
|
||||
createMode = windows.CREATE_ALWAYS
|
||||
case mode&os.O_CREATE == os.O_CREATE:
|
||||
createMode = windows.OPEN_ALWAYS
|
||||
case mode&os.O_TRUNC == os.O_TRUNC:
|
||||
createMode = windows.TRUNCATE_EXISTING
|
||||
default:
|
||||
createMode = windows.OPEN_EXISTING
|
||||
}
|
||||
|
||||
handle, err := syscall.CreateFile(pathP, access, shareMode, sa, createMode, syscall.FILE_ATTRIBUTE_NORMAL, 0)
|
||||
if err != nil {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: err}
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(handle), path), nil
|
||||
}
|
||||
|
||||
const processEntrySize = uint32(unsafe.Sizeof(windows.ProcessEntry32{}))
|
||||
|
||||
// hasLocalDnsServerRunning reports whether we are on Windows and having Dns server running.
|
||||
func hasLocalDnsServerRunning() bool {
|
||||
h, e := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPPROCESS, 0)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
defer windows.CloseHandle(h)
|
||||
p := windows.ProcessEntry32{Size: processEntrySize}
|
||||
for {
|
||||
e := windows.Process32Next(h, &p)
|
||||
if e != nil {
|
||||
return false
|
||||
}
|
||||
if strings.ToLower(windows.UTF16ToString(p.ExeFile[:])) == "dns.exe" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isRunningOnDomainControllerWindows() (bool, int) {
|
||||
whost := host.NewWmiLocalHost()
|
||||
q := query.NewWmiQuery("Win32_ComputerSystem")
|
||||
instances, err := instance.GetWmiInstancesFromHost(whost, string(constant.CimV2), q)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("WMI query failed")
|
||||
return false, 0
|
||||
}
|
||||
if instances == nil {
|
||||
mainLog.Load().Debug().Msg("WMI query returned nil instances")
|
||||
return false, 0
|
||||
}
|
||||
defer instances.Close()
|
||||
|
||||
if len(instances) == 0 {
|
||||
mainLog.Load().Debug().Msg("no rows returned from Win32_ComputerSystem")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
val, err := instances[0].GetProperty("DomainRole")
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("failed to get DomainRole property")
|
||||
return false, 0
|
||||
}
|
||||
if val == nil {
|
||||
mainLog.Load().Debug().Msg("DomainRole property is nil")
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Safely handle varied types: string or integer
|
||||
var roleInt int
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
// "4", "5", etc.
|
||||
parsed, parseErr := strconv.Atoi(v)
|
||||
if parseErr != nil {
|
||||
mainLog.Load().Debug().Err(parseErr).Msgf("failed to parse DomainRole value %q", v)
|
||||
return false, 0
|
||||
}
|
||||
roleInt = parsed
|
||||
case int8, int16, int32, int64:
|
||||
roleInt = int(reflect.ValueOf(v).Int())
|
||||
case uint8, uint16, uint32, uint64:
|
||||
roleInt = int(reflect.ValueOf(v).Uint())
|
||||
default:
|
||||
mainLog.Load().Debug().Msgf("unexpected DomainRole type: %T value=%v", v, v)
|
||||
return false, 0
|
||||
}
|
||||
|
||||
// Check if role indicates a domain controller
|
||||
isDC := roleInt == BackupDomainController || roleInt == PrimaryDomainController
|
||||
return isDC, roleInt
|
||||
}
|
||||
25
cmd/cli/service_windows_test.go
Normal file
25
cmd/cli/service_windows_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_hasLocalDnsServerRunning(t *testing.T) {
|
||||
start := time.Now()
|
||||
hasDns := hasLocalDnsServerRunning()
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
hasDnsPowershell := hasLocalDnsServerRunningPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if hasDns != hasDnsPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", hasDnsPowershell, hasDns)
|
||||
}
|
||||
}
|
||||
|
||||
func hasLocalDnsServerRunningPowershell() bool {
|
||||
_, err := powershell("Get-Process -Name DNS")
|
||||
return err == nil
|
||||
}
|
||||
126
cmd/cli/upstream_monitor.go
Normal file
126
cmd/cli/upstream_monitor.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 50
|
||||
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
|
||||
checkUpstreamBackoffSleep = 2 * time.Second
|
||||
)
|
||||
|
||||
// upstreamMonitor performs monitoring upstreams health.
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
|
||||
mu sync.RWMutex
|
||||
checking map[string]bool
|
||||
down map[string]bool
|
||||
failureReq map[string]uint64
|
||||
recovered map[string]bool
|
||||
|
||||
// failureTimerActive tracks if a timer is already running for a given upstream.
|
||||
failureTimerActive map[string]bool
|
||||
}
|
||||
|
||||
func newUpstreamMonitor(cfg *ctrld.Config) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
recovered: make(map[string]bool),
|
||||
failureTimerActive: make(map[string]bool),
|
||||
}
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
um.reset(upstream)
|
||||
}
|
||||
um.reset(upstreamOS)
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
|
||||
// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down
|
||||
// within 10 seconds if failures persist, without spawning duplicate goroutines.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
if um.recovered[upstream] {
|
||||
mainLog.Load().Debug().Msgf("upstream %q is recovered, skipping failure count increase", upstream)
|
||||
return
|
||||
}
|
||||
|
||||
um.failureReq[upstream] += 1
|
||||
failedCount := um.failureReq[upstream]
|
||||
|
||||
// Log the updated failure count.
|
||||
mainLog.Load().Debug().Msgf("upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// If this is the first failure and no timer is running, start a 10-second timer.
|
||||
if failedCount == 1 && !um.failureTimerActive[upstream] {
|
||||
um.failureTimerActive[upstream] = true
|
||||
go func(upstream string) {
|
||||
time.Sleep(10 * time.Second)
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
// If no success occurred during the 10-second window (i.e. counter remains > 0)
|
||||
// and the upstream is not in a recovered state, mark it as down.
|
||||
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
|
||||
}
|
||||
// Reset the timer flag so that a new timer can be spawned if needed.
|
||||
um.failureTimerActive[upstream] = false
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
mainLog.Load().Warn().Msgf("upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
func (um *upstreamMonitor) isDown(upstream string) bool {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
return um.down[upstream]
|
||||
}
|
||||
|
||||
// reset marks an upstream as up and set failed queries counter to zero.
|
||||
func (um *upstreamMonitor) reset(upstream string) {
|
||||
um.mu.Lock()
|
||||
um.failureReq[upstream] = 0
|
||||
um.down[upstream] = false
|
||||
um.recovered[upstream] = true
|
||||
um.mu.Unlock()
|
||||
go func() {
|
||||
// debounce the recovery to avoid incrementing failure counts already in flight
|
||||
time.Sleep(1 * time.Second)
|
||||
um.mu.Lock()
|
||||
um.recovered[upstream] = false
|
||||
um.mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// countHealthy returns the number of upstreams in the provided map that are considered healthy.
|
||||
func (um *upstreamMonitor) countHealthy(upstreams []string) int {
|
||||
var count int
|
||||
um.mu.RLock()
|
||||
for _, upstream := range upstreams {
|
||||
if !um.down[upstream] {
|
||||
count++
|
||||
}
|
||||
}
|
||||
um.mu.RUnlock()
|
||||
return count
|
||||
}
|
||||
258
cmd/cli/vpn_dns.go
Normal file
258
cmd/cli/vpn_dns.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"tailscale.com/net/netmon"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// vpnDNSExemption represents a VPN DNS server that needs pf/WFP exemption,
|
||||
// including the interface it was discovered on. The interface is used on macOS
|
||||
// to create interface-scoped pf exemptions that allow the VPN's local DNS
|
||||
// handler (e.g., Tailscale's MagicDNS Network Extension) to receive queries
|
||||
// from all processes — not just ctrld.
|
||||
type vpnDNSExemption struct {
|
||||
Server string // DNS server IP (e.g., "100.100.100.100")
|
||||
Interface string // Interface name from scutil (e.g., "utun11"), may be empty
|
||||
IsExitMode bool // True if this VPN is in exit/full-tunnel mode (all traffic routed through VPN)
|
||||
}
|
||||
|
||||
// vpnDNSExemptFunc is called when VPN DNS servers change, to update
|
||||
// the intercept layer (WFP/pf) to permit VPN DNS traffic.
|
||||
type vpnDNSExemptFunc func(exemptions []vpnDNSExemption) error
|
||||
|
||||
// vpnDNSManager tracks active VPN DNS configurations and provides
|
||||
// domain-to-upstream routing for VPN split DNS.
|
||||
type vpnDNSManager struct {
|
||||
mu sync.RWMutex
|
||||
configs []ctrld.VPNDNSConfig
|
||||
// Map of domain suffix → DNS servers for fast lookup
|
||||
routes map[string][]string
|
||||
// DNS servers from VPN interfaces that have no domain/suffix config.
|
||||
// These are NOT added to the global OS resolver. They're only used
|
||||
// as additional nameservers for queries that match split-DNS rules
|
||||
// (from ctrld config, AD domain, or VPN suffix config).
|
||||
domainlessServers []string
|
||||
// Called when VPN DNS server list changes, to update intercept exemptions.
|
||||
onServersChanged vpnDNSExemptFunc
|
||||
}
|
||||
|
||||
// newVPNDNSManager creates a new manager. Only call when dnsIntercept is active.
|
||||
// exemptFunc is called whenever VPN DNS servers are discovered/changed, to update
|
||||
// the OS-level intercept rules to permit ctrld's outbound queries to those IPs.
|
||||
func newVPNDNSManager(exemptFunc vpnDNSExemptFunc) *vpnDNSManager {
|
||||
return &vpnDNSManager{
|
||||
routes: make(map[string][]string),
|
||||
onServersChanged: exemptFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh re-discovers VPN DNS configs from the OS.
|
||||
// Called on network change events.
|
||||
func (m *vpnDNSManager) Refresh(guardAgainstNoNameservers bool) {
|
||||
logger := mainLog.Load()
|
||||
|
||||
logger.Debug().Msg("Refreshing VPN DNS configurations")
|
||||
configs := ctrld.DiscoverVPNDNS(context.Background())
|
||||
|
||||
// Detect exit mode: if the default route goes through a VPN DNS interface,
|
||||
// the VPN is routing ALL traffic (exit node / full tunnel). This is more
|
||||
// reliable than scutil flag parsing because the routing table is the ground
|
||||
// truth for traffic flow, regardless of how the VPN presents itself in scutil.
|
||||
if dri, err := netmon.DefaultRouteInterface(); err == nil && dri != "" {
|
||||
for i := range configs {
|
||||
if configs[i].InterfaceName == dri {
|
||||
if !configs[i].IsExitMode {
|
||||
logger.Info().Msgf("VPN DNS on %s: default route interface match — EXIT MODE (route-based detection)", dri)
|
||||
}
|
||||
configs[i].IsExitMode = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.configs = configs
|
||||
m.routes = make(map[string][]string)
|
||||
|
||||
// Build domain -> DNS servers mapping
|
||||
for _, config := range configs {
|
||||
logger.Debug().Msgf("Processing VPN interface %s with %d domains and %d servers",
|
||||
config.InterfaceName, len(config.Domains), len(config.Servers))
|
||||
|
||||
for _, domain := range config.Domains {
|
||||
// Normalize domain: remove leading dot, Linux routing domain prefix (~),
|
||||
// and convert to lowercase.
|
||||
domain = strings.TrimPrefix(domain, "~")
|
||||
domain = strings.TrimPrefix(domain, ".")
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
if domain != "" {
|
||||
m.routes[domain] = append([]string{}, config.Servers...)
|
||||
logger.Debug().Msgf("Added VPN DNS route: %s -> %v", domain, config.Servers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect unique VPN DNS exemptions (server + interface) for pf/WFP rules.
|
||||
type exemptionKey struct{ server, iface string }
|
||||
seen := make(map[exemptionKey]bool)
|
||||
var exemptions []vpnDNSExemption
|
||||
for _, config := range configs {
|
||||
for _, server := range config.Servers {
|
||||
key := exemptionKey{server, config.InterfaceName}
|
||||
if !seen[key] {
|
||||
seen[key] = true
|
||||
exemptions = append(exemptions, vpnDNSExemption{
|
||||
Server: server,
|
||||
Interface: config.InterfaceName,
|
||||
IsExitMode: config.IsExitMode,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect domain-less VPN DNS servers. These are NOT added to the global
|
||||
// OS resolver (that would pollute captive portal / DHCP flows). Instead,
|
||||
// they're stored separately and only used for queries that match existing
|
||||
// split-DNS rules (from ctrld config, AD domain, or VPN suffix config).
|
||||
var domainlessServers []string
|
||||
seen2 := make(map[string]bool)
|
||||
for _, config := range configs {
|
||||
if len(config.Domains) == 0 && len(config.Servers) > 0 {
|
||||
logger.Debug().Msgf("VPN interface %s has DNS servers but no domains, storing as split-rule fallback: %v",
|
||||
config.InterfaceName, config.Servers)
|
||||
for _, s := range config.Servers {
|
||||
if !seen2[s] {
|
||||
seen2[s] = true
|
||||
domainlessServers = append(domainlessServers, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
m.domainlessServers = domainlessServers
|
||||
|
||||
logger.Debug().Msgf("VPN DNS refresh completed: %d configs, %d routes, %d domainless servers, %d unique exemptions",
|
||||
len(m.configs), len(m.routes), len(m.domainlessServers), len(exemptions))
|
||||
|
||||
// Update intercept rules to permit VPN DNS traffic.
|
||||
// Always call onServersChanged — including when exemptions is empty — so that
|
||||
// stale exemptions from a previous VPN session get cleared on disconnect.
|
||||
if m.onServersChanged != nil {
|
||||
if err := m.onServersChanged(exemptions); err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to update intercept exemptions for VPN DNS servers")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpstreamForDomain checks if the domain matches any VPN search domain.
|
||||
// Returns VPN DNS servers if matched, nil otherwise.
|
||||
func (m *vpnDNSManager) UpstreamForDomain(domain string) []string {
|
||||
if domain == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
domain = strings.TrimSuffix(domain, ".")
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
if servers, ok := m.routes[domain]; ok {
|
||||
return append([]string{}, servers...)
|
||||
}
|
||||
|
||||
for vpnDomain, servers := range m.routes {
|
||||
if strings.HasSuffix(domain, "."+vpnDomain) {
|
||||
return append([]string{}, servers...)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DomainlessServers returns VPN DNS servers that have no associated domains.
|
||||
// These should only be used for queries matching split-DNS rules, not for
|
||||
// general OS resolver queries (to avoid polluting captive portal / DHCP flows).
|
||||
func (m *vpnDNSManager) DomainlessServers() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return append([]string{}, m.domainlessServers...)
|
||||
}
|
||||
|
||||
// CurrentServers returns the current set of unique VPN DNS server IPs.
|
||||
func (m *vpnDNSManager) CurrentServers() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
seen := make(map[string]bool)
|
||||
var servers []string
|
||||
for _, ss := range m.routes {
|
||||
for _, s := range ss {
|
||||
if !seen[s] {
|
||||
seen[s] = true
|
||||
servers = append(servers, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
// CurrentExemptions returns VPN DNS server + interface pairs for pf exemption rules.
|
||||
func (m *vpnDNSManager) CurrentExemptions() []vpnDNSExemption {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
type key struct{ server, iface string }
|
||||
seen := make(map[key]bool)
|
||||
var exemptions []vpnDNSExemption
|
||||
for _, config := range m.configs {
|
||||
for _, server := range config.Servers {
|
||||
k := key{server, config.InterfaceName}
|
||||
if !seen[k] {
|
||||
seen[k] = true
|
||||
exemptions = append(exemptions, vpnDNSExemption{
|
||||
Server: server,
|
||||
Interface: config.InterfaceName,
|
||||
IsExitMode: config.IsExitMode,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return exemptions
|
||||
}
|
||||
|
||||
// Routes returns a copy of the current VPN DNS routes for debugging.
|
||||
func (m *vpnDNSManager) Routes() map[string][]string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
routes := make(map[string][]string)
|
||||
for domain, servers := range m.routes {
|
||||
routes[domain] = append([]string{}, servers...)
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// upstreamConfigFor creates a legacy upstream configuration for the given VPN DNS server.
|
||||
func (m *vpnDNSManager) upstreamConfigFor(server string) *ctrld.UpstreamConfig {
|
||||
// Use net.JoinHostPort to correctly handle both IPv4 and IPv6 addresses.
|
||||
// Previously, the strings.Contains(":") check would skip appending ":53"
|
||||
// for IPv6 addresses (they contain colons), leaving a bare address like
|
||||
// "2a0d:6fc0:9b0:3600::1" which net.Dial rejects with "too many colons".
|
||||
// net.JoinHostPort produces "[2a0d:6fc0:9b0:3600::1]:53" as required.
|
||||
endpoint := net.JoinHostPort(server, "53")
|
||||
|
||||
return &ctrld.UpstreamConfig{
|
||||
Name: "VPN DNS",
|
||||
Type: ctrld.ResolverTypeLegacy,
|
||||
Endpoint: endpoint,
|
||||
Timeout: 2000,
|
||||
}
|
||||
}
|
||||
20
cmd/cli/winres/winres.json
Normal file
20
cmd/cli/winres/winres.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"RT_VERSION": {
|
||||
"#1": {
|
||||
"0000": {
|
||||
"fixed": {
|
||||
"file_version": "0.0.0.1"
|
||||
},
|
||||
"info": {
|
||||
"0409": {
|
||||
"CompanyName": "ControlD Inc",
|
||||
"FileDescription": "Control D DNS daemon",
|
||||
"ProductName": "ctrld",
|
||||
"InternalName": "ctrld",
|
||||
"LegalCopyright": "ControlD Inc 2024"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
4
cmd/cli/winres_windows.go
Normal file
4
cmd/cli/winres_windows.go
Normal file
@@ -0,0 +1,4 @@
|
||||
//go:generate go-winres make --product-version=git-tag --file-version=git-tag
|
||||
package cli
|
||||
|
||||
// Placeholder file for windows builds.
|
||||
959
cmd/ctrld/cli.go
959
cmd/ctrld/cli.go
@@ -1,959 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cuonglm/osinfo"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/interfaces"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/certs"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
var (
|
||||
version = "dev"
|
||||
commit = "none"
|
||||
)
|
||||
|
||||
var (
|
||||
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
defaultConfigWritten = false
|
||||
defaultConfigFile = "ctrld.toml"
|
||||
rootCertPool *x509.CertPool
|
||||
)
|
||||
|
||||
var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"}
|
||||
|
||||
func isNoConfigStart(cmd *cobra.Command) bool {
|
||||
for _, flagName := range basicModeFlags {
|
||||
if cmd.Flags().Lookup(flagName).Changed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const rootShortDesc = `
|
||||
__ .__ .___
|
||||
_____/ |________| | __| _/
|
||||
_/ ___\ __\_ __ \ | / __ |
|
||||
\ \___| | | | \/ |__/ /_/ |
|
||||
\___ >__| |__| |____/\____ |
|
||||
\/ dns forwarding proxy \/
|
||||
`
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "ctrld",
|
||||
Short: strings.TrimLeft(rootShortDesc, "\n"),
|
||||
Version: curVersion(),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
}
|
||||
|
||||
func curVersion() string {
|
||||
if version != "dev" && !strings.HasPrefix(version, "v") {
|
||||
version = "v" + version
|
||||
}
|
||||
if len(commit) > 7 {
|
||||
commit = commit[:7]
|
||||
}
|
||||
return fmt.Sprintf("%s-%s", version, commit)
|
||||
}
|
||||
|
||||
func initCLI() {
|
||||
// Enable opening via explorer.exe on Windows.
|
||||
// See: https://github.com/spf13/cobra/issues/844.
|
||||
cobra.MousetrapHelpText = ""
|
||||
cobra.EnableCommandSorting = false
|
||||
|
||||
rootCmd.PersistentFlags().CountVarP(
|
||||
&verbose,
|
||||
"verbose",
|
||||
"v",
|
||||
`verbose log output, "-v" basic logging, "-vv" debug level logging`,
|
||||
)
|
||||
rootCmd.PersistentFlags().BoolVarP(
|
||||
&silent,
|
||||
"silent",
|
||||
"s",
|
||||
false,
|
||||
`do not write any log output`,
|
||||
)
|
||||
rootCmd.SetHelpCommand(&cobra.Command{Hidden: true})
|
||||
rootCmd.CompletionOptions.HiddenDefaultCmd = true
|
||||
|
||||
runCmd := &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "Run the DNS proxy server",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if daemon && runtime.GOOS == "windows" {
|
||||
mainLog.Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.")
|
||||
}
|
||||
|
||||
waitCh := make(chan struct{})
|
||||
stopCh := make(chan struct{})
|
||||
if !daemon {
|
||||
// We need to call s.Run() as soon as possible to response to the OS manager, so it
|
||||
// can see ctrld is running and don't mark ctrld as failed service.
|
||||
go func() {
|
||||
p := &prog{
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
}
|
||||
s, err := service.New(p, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed create new service")
|
||||
}
|
||||
s = newService(s)
|
||||
if err := s.Run(); err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to start service")
|
||||
}
|
||||
}()
|
||||
}
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
readBase64Config(configBase64)
|
||||
processNoConfigFlags(noConfigStart)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
mainLog.Info().Msgf("starting ctrld %s", curVersion())
|
||||
oi := osinfo.New()
|
||||
mainLog.Info().Msgf("os: %s", oi.String())
|
||||
|
||||
// Wait for network up.
|
||||
if !ctrldnet.Up() {
|
||||
mainLog.Fatal().Msg("network is not up yet")
|
||||
}
|
||||
processLogAndCacheFlags()
|
||||
// Log config do not have thing to validate, so it's safe to init log here,
|
||||
// so it's able to log information in processCDFlags.
|
||||
initLogging()
|
||||
|
||||
if setupRouter {
|
||||
s, errCh := runDNSServerForNTPD(router.ListenAddress())
|
||||
if err := router.PreRun(); err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to perform router pre-start check")
|
||||
}
|
||||
if err := s.Shutdown(); err != nil && errCh != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to shutdown dns server for ntpd")
|
||||
}
|
||||
}
|
||||
|
||||
processCDFlags()
|
||||
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("invalid config: %v", err)
|
||||
}
|
||||
initCache()
|
||||
|
||||
if daemon {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to find the binary")
|
||||
os.Exit(1)
|
||||
}
|
||||
curDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to get current working directory")
|
||||
os.Exit(1)
|
||||
}
|
||||
// If running as daemon, re-run the command in background, with daemon off.
|
||||
cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...)
|
||||
cmd.Dir = curDir
|
||||
if err := cmd.Start(); err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to start process as daemon")
|
||||
os.Exit(1)
|
||||
}
|
||||
mainLog.Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if setupRouter {
|
||||
switch platform := router.Name(); {
|
||||
case platform == router.DDWrt:
|
||||
rootCertPool = certs.CACertPool()
|
||||
fallthrough
|
||||
case platform != "":
|
||||
mainLog.Debug().Msg("Router setup")
|
||||
err := router.Configure(&cfg)
|
||||
if errors.Is(err, router.ErrNotSupported) {
|
||||
unsupportedPlatformHelp(cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to configure router")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
close(waitCh)
|
||||
<-stopCh
|
||||
},
|
||||
}
|
||||
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
|
||||
runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = runCmd.Flags().MarkHidden("dev")
|
||||
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
|
||||
_ = runCmd.Flags().MarkHidden("homedir")
|
||||
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
runCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
|
||||
_ = runCmd.Flags().MarkHidden("router")
|
||||
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
startCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Install and start the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
sc := &service.Config{}
|
||||
*sc = *svcConfig
|
||||
osArgs := os.Args[2:]
|
||||
if os.Args[1] == "service" {
|
||||
osArgs = os.Args[3:]
|
||||
}
|
||||
setDependencies(sc)
|
||||
sc.Arguments = append([]string{"run"}, osArgs...)
|
||||
if err := router.ConfigureService(sc); err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to configure service on router")
|
||||
}
|
||||
|
||||
// No config path, generating config in HOME directory.
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
setWorkingDirectory(sc, dir)
|
||||
if configPath == "" && writeDefaultConfig {
|
||||
defaultConfigFile = filepath.Join(dir, defaultConfigFile)
|
||||
}
|
||||
sc.Arguments = append(sc.Arguments, "--homedir="+dir)
|
||||
}
|
||||
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
logPath := cfg.Service.LogPath
|
||||
cfg.Service.LogPath = ""
|
||||
initLogging()
|
||||
cfg.Service.LogPath = logPath
|
||||
|
||||
processCDFlags()
|
||||
|
||||
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("invalid config: %v", err)
|
||||
}
|
||||
|
||||
// Explicitly passing config, so on system where home directory could not be obtained,
|
||||
// or sub-process env is different with the parent, we still behave correctly and use
|
||||
// the expected config file.
|
||||
if configPath == "" {
|
||||
sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile)
|
||||
}
|
||||
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, sc)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
tasks := []task{
|
||||
{s.Stop, false},
|
||||
{s.Uninstall, false},
|
||||
{s.Install, false},
|
||||
{s.Start, true},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
if err := router.PostInstall(svcConfig); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("post installation failed, please check system/service log for details error")
|
||||
return
|
||||
}
|
||||
status, err := serviceStatus(s)
|
||||
if err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not get service status")
|
||||
return
|
||||
}
|
||||
|
||||
domain := cfg.Upstream["0"].VerifyDomain()
|
||||
status = selfCheckStatus(status, domain)
|
||||
switch status {
|
||||
case service.StatusRunning:
|
||||
mainLog.Notice().Msg("Service started")
|
||||
default:
|
||||
mainLog.Error().Msg("Service did not start, please check system/service log for details error")
|
||||
if runtime.GOOS == "linux" {
|
||||
prog.resetDNS()
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
prog.setDNS()
|
||||
}
|
||||
},
|
||||
}
|
||||
// Keep these flags in sync with runCmd above, except for "-d".
|
||||
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = startCmd.Flags().MarkHidden("dev")
|
||||
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
|
||||
_ = startCmd.Flags().MarkHidden("router")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "stop",
|
||||
Short: "Stop the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
initLogging()
|
||||
if doTasks([]task{{s.Stop, true}}) {
|
||||
prog.resetDNS()
|
||||
mainLog.Notice().Msg("Service stopped")
|
||||
}
|
||||
},
|
||||
}
|
||||
stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
|
||||
restartCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
initLogging()
|
||||
if doTasks([]task{{s.Restart, true}}) {
|
||||
mainLog.Notice().Msg("Service restarted")
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
status, err := serviceStatus(s)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
switch status {
|
||||
case service.StatusUnknown:
|
||||
mainLog.Notice().Msg("Unknown status")
|
||||
os.Exit(2)
|
||||
case service.StatusRunning:
|
||||
mainLog.Notice().Msg("Service is running")
|
||||
os.Exit(0)
|
||||
case service.StatusStopped:
|
||||
mainLog.Notice().Msg("Service is stopped")
|
||||
os.Exit(1)
|
||||
}
|
||||
},
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
// On darwin, running status command without privileges may return wrong information.
|
||||
statusCmd.PreRun = func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
}
|
||||
}
|
||||
|
||||
uninstallCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Long: `Stop and uninstall the ctrld service.
|
||||
|
||||
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
tasks := []task{
|
||||
{s.Stop, false},
|
||||
{s.Uninstall, true},
|
||||
}
|
||||
initLogging()
|
||||
if doTasks(tasks) {
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
prog.resetDNS()
|
||||
mainLog.Debug().Msg("Router cleanup")
|
||||
if err := router.Cleanup(svcConfig); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
mainLog.Notice().Msg("Service uninstalled")
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`)
|
||||
|
||||
listIfacesCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List network interfaces of the host",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
err := interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
fmt.Printf("Index : %d\n", i.Index)
|
||||
fmt.Printf("Name : %s\n", i.Name)
|
||||
addrs, _ := i.Addrs()
|
||||
for i, ipaddr := range addrs {
|
||||
if i == 0 {
|
||||
fmt.Printf("Addrs : %v\n", ipaddr)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" %v\n", ipaddr)
|
||||
}
|
||||
for i, dns := range currentDNS(i.Interface) {
|
||||
if i == 0 {
|
||||
fmt.Printf("DNS : %s\n", dns)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" : %s\n", dns)
|
||||
}
|
||||
println()
|
||||
})
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
}
|
||||
},
|
||||
}
|
||||
interfacesCmd := &cobra.Command{
|
||||
Use: "interfaces",
|
||||
Short: "Manage network interfaces",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
listIfacesCmd.Use,
|
||||
},
|
||||
}
|
||||
interfacesCmd.AddCommand(listIfacesCmd)
|
||||
|
||||
serviceCmd := &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "Manage ctrld service",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
statusCmd.Use,
|
||||
stopCmd.Use,
|
||||
restartCmd.Use,
|
||||
statusCmd.Use,
|
||||
uninstallCmd.Use,
|
||||
interfacesCmd.Use,
|
||||
},
|
||||
}
|
||||
serviceCmd.AddCommand(startCmd)
|
||||
serviceCmd.AddCommand(stopCmd)
|
||||
serviceCmd.AddCommand(restartCmd)
|
||||
serviceCmd.AddCommand(statusCmd)
|
||||
serviceCmd.AddCommand(uninstallCmd)
|
||||
serviceCmd.AddCommand(interfacesCmd)
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
startCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Quick start service and configure DNS on interface",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
startCmd.Run(cmd, args)
|
||||
},
|
||||
}
|
||||
startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
||||
rootCmd.AddCommand(startCmdAlias)
|
||||
stopCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "stop",
|
||||
Short: "Quick stop service and remove DNS from interface",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
stopCmd.Run(cmd, args)
|
||||
},
|
||||
}
|
||||
stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
|
||||
rootCmd.AddCommand(stopCmdAlias)
|
||||
}
|
||||
|
||||
func writeConfigFile() error {
|
||||
if cfu := v.ConfigFileUsed(); cfu != "" {
|
||||
defaultConfigFile = cfu
|
||||
} else if configPath != "" {
|
||||
defaultConfigFile = configPath
|
||||
}
|
||||
f, err := os.OpenFile(defaultConfigFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0o644))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if cdUID != "" {
|
||||
if _, err := f.WriteString("# AUTO-GENERATED VIA CD FLAG - DO NOT MODIFY\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
enc := toml.NewEncoder(f).SetIndentTables(true)
|
||||
if err := enc.Encode(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readConfigFile(writeDefaultConfig bool) bool {
|
||||
// If err == nil, there's a config supplied via `--config`, no default config written.
|
||||
err := v.ReadInConfig()
|
||||
if err == nil {
|
||||
mainLog.Info().Msg("loading config file from: " + v.ConfigFileUsed())
|
||||
defaultConfigFile = v.ConfigFileUsed()
|
||||
return true
|
||||
}
|
||||
|
||||
if !writeDefaultConfig {
|
||||
return false
|
||||
}
|
||||
|
||||
// If error is viper.ConfigFileNotFoundError, write default config.
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal default config: %v", err)
|
||||
}
|
||||
if err := writeConfigFile(); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to write default config file: %v", err)
|
||||
} else {
|
||||
fp, err := filepath.Abs(defaultConfigFile)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("failed to get default config file path: %v", err)
|
||||
}
|
||||
mainLog.Info().Msg("writing default config file to: " + fp)
|
||||
}
|
||||
defaultConfigWritten = true
|
||||
return false
|
||||
}
|
||||
// Otherwise, report fatal error and exit.
|
||||
mainLog.Fatal().Msgf("failed to decode config file: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
func readBase64Config(configBase64 string) {
|
||||
if configBase64 == "" {
|
||||
return
|
||||
}
|
||||
configStr, err := base64.StdEncoding.DecodeString(configBase64)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("invalid base64 config: %v", err)
|
||||
}
|
||||
if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to read base64 config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func processNoConfigFlags(noConfigStart bool) {
|
||||
if !noConfigStart {
|
||||
return
|
||||
}
|
||||
if listenAddress == "" || primaryUpstream == "" {
|
||||
mainLog.Fatal().Msg(`"listen" and "primary_upstream" flags must be set in no config mode`)
|
||||
}
|
||||
processListenFlag()
|
||||
|
||||
endpointAndTyp := func(endpoint string) (string, string) {
|
||||
typ := ctrld.ResolverTypeFromEndpoint(endpoint)
|
||||
return strings.TrimPrefix(endpoint, "quic://"), typ
|
||||
}
|
||||
pEndpoint, pType := endpointAndTyp(primaryUpstream)
|
||||
upstream := map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
Name: pEndpoint,
|
||||
Endpoint: pEndpoint,
|
||||
Type: pType,
|
||||
Timeout: 5000,
|
||||
},
|
||||
}
|
||||
if secondaryUpstream != "" {
|
||||
sEndpoint, sType := endpointAndTyp(secondaryUpstream)
|
||||
upstream["1"] = &ctrld.UpstreamConfig{
|
||||
Name: sEndpoint,
|
||||
Endpoint: sEndpoint,
|
||||
Type: sType,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(domains))
|
||||
for _, domain := range domains {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{"upstream.1"}})
|
||||
}
|
||||
lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"]
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
|
||||
}
|
||||
v.Set("upstream", upstream)
|
||||
}
|
||||
|
||||
func processCDFlags() {
|
||||
if cdUID == "" {
|
||||
return
|
||||
}
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
logger := mainLog.With().Str("mode", "cd").Logger()
|
||||
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
|
||||
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("failed to create new service")
|
||||
return
|
||||
}
|
||||
|
||||
if netIface, _ := netInterface(iface); netIface != nil {
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
}
|
||||
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
logger.Warn().Err(err).Msg("something went wrong while restoring DNS")
|
||||
} else {
|
||||
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
|
||||
}
|
||||
}
|
||||
|
||||
tasks := []task{{s.Uninstall, true}}
|
||||
if doTasks(tasks) {
|
||||
logger.Info().Msg("uninstalled service")
|
||||
}
|
||||
logger.Fatal().Err(uer).Msg("failed to fetch resolver config")
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info().Msg("generating ctrld config from Control-D configuration")
|
||||
if resolverConfig.Ctrld.CustomConfig != "" {
|
||||
logger.Info().Msg("using defined custom config of Control-D resolver")
|
||||
readBase64Config(resolverConfig.Ctrld.CustomConfig)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
for _, listener := range cfg.Listener {
|
||||
if listener.IP == "" {
|
||||
listener.IP = randomLocalIP()
|
||||
}
|
||||
if listener.Port == 0 {
|
||||
listener.Port = 53
|
||||
}
|
||||
}
|
||||
// On router, we want to keep the listener address point to dnsmasq listener, aka 127.0.0.1:53.
|
||||
if router.Name() != "" {
|
||||
if lc := cfg.Listener["0"]; lc != nil {
|
||||
lc.IP = "127.0.0.1"
|
||||
lc.Port = 53
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cfg = ctrld.Config{}
|
||||
cfg.Network = make(map[string]*ctrld.NetworkConfig)
|
||||
cfg.Network["0"] = &ctrld.NetworkConfig{
|
||||
Name: "Network 0",
|
||||
Cidrs: []string{"0.0.0.0/0"},
|
||||
}
|
||||
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
|
||||
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
|
||||
Endpoint: resolverConfig.DOH,
|
||||
Type: ctrld.ResolverTypeDOH,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
|
||||
for _, domain := range resolverConfig.Exclude {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
|
||||
cfg.Listener["0"] = &ctrld.ListenerConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 53,
|
||||
Policy: &ctrld.ListenerPolicyConfig{
|
||||
Name: "My Policy",
|
||||
Rules: rules,
|
||||
},
|
||||
}
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
if err := writeConfigFile(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to write config file")
|
||||
} else {
|
||||
logger.Info().Msg("writing config file to: " + defaultConfigFile)
|
||||
}
|
||||
}
|
||||
|
||||
func processListenFlag() {
|
||||
if listenAddress == "" {
|
||||
return
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(listenAddress)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("invalid listener address: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("invalid port number: %v", err)
|
||||
}
|
||||
lc := &ctrld.ListenerConfig{
|
||||
IP: host,
|
||||
Port: port,
|
||||
}
|
||||
v.Set("listener", map[string]*ctrld.ListenerConfig{
|
||||
"0": lc,
|
||||
})
|
||||
}
|
||||
|
||||
func processLogAndCacheFlags() {
|
||||
if logPath != "" {
|
||||
cfg.Service.LogLevel = "debug"
|
||||
cfg.Service.LogPath = logPath
|
||||
}
|
||||
|
||||
if cacheSize != 0 {
|
||||
cfg.Service.CacheEnable = true
|
||||
cfg.Service.CacheSize = cacheSize
|
||||
}
|
||||
v.Set("service", cfg.Service)
|
||||
}
|
||||
|
||||
func netInterface(ifaceName string) (*net.Interface, error) {
|
||||
if ifaceName == "auto" {
|
||||
ifaceName = defaultIfaceName()
|
||||
}
|
||||
var iface *net.Interface
|
||||
err := interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
if i.Name == ifaceName {
|
||||
iface = i.Interface
|
||||
}
|
||||
})
|
||||
if iface == nil {
|
||||
return nil, errors.New("interface not found")
|
||||
}
|
||||
if err := patchNetIfaceName(iface); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return iface, err
|
||||
}
|
||||
|
||||
func defaultIfaceName() string {
|
||||
dri, err := interfaces.DefaultRouteInterface()
|
||||
if err != nil {
|
||||
// On WSL 1, the route table does not have any default route. But the fact that
|
||||
// it only uses /etc/resolv.conf for setup DNS, so we can use "lo" here.
|
||||
if oi := osinfo.New(); strings.Contains(oi.String(), "Microsoft") {
|
||||
return "lo"
|
||||
}
|
||||
mainLog.Fatal().Err(err).Msg("failed to get default route interface")
|
||||
}
|
||||
return dri
|
||||
}
|
||||
|
||||
func selfCheckStatus(status service.Status, domain string) service.Status {
|
||||
if domain == "" {
|
||||
// Nothing to do, return the status as-is.
|
||||
return status
|
||||
}
|
||||
c := new(dns.Client)
|
||||
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
|
||||
bo.LogLongerThan = 500 * time.Millisecond
|
||||
ctx := context.Background()
|
||||
maxAttempts := 20
|
||||
mainLog.Debug().Msg("Performing self-check")
|
||||
var (
|
||||
lcChanged map[string]*ctrld.ListenerConfig
|
||||
mu sync.Mutex
|
||||
)
|
||||
v.OnConfigChange(func(in fsnotify.Event) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err := v.UnmarshalKey("listener", &lcChanged); err != nil {
|
||||
mainLog.Error().Msgf("failed to unmarshal listener config: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
v.WatchConfig()
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
lc := cfg.Listener["0"]
|
||||
mu.Lock()
|
||||
if lcChanged != nil {
|
||||
lc = lcChanged["0"]
|
||||
}
|
||||
mu.Unlock()
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(domain+".", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
r, _, err := c.ExchangeContext(ctx, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)))
|
||||
if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 {
|
||||
mainLog.Debug().Msgf("self-check against %q succeeded", domain)
|
||||
return status
|
||||
}
|
||||
bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", err))
|
||||
}
|
||||
mainLog.Debug().Msgf("self-check against %q failed", domain)
|
||||
return service.StatusUnknown
|
||||
}
|
||||
|
||||
func unsupportedPlatformHelp(cmd *cobra.Command) {
|
||||
mainLog.Error().Msg("Unsupported or incorrectly chosen router platform. Please open an issue and provide all relevant information: https://github.com/Control-D-Inc/ctrld/issues/new")
|
||||
}
|
||||
|
||||
func userHomeDir() (string, error) {
|
||||
switch router.Name() {
|
||||
case router.DDWrt, router.Merlin, router.Tomato:
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Dir(exe), nil
|
||||
}
|
||||
// viper will expand for us.
|
||||
if runtime.GOOS == "windows" {
|
||||
return os.UserHomeDir()
|
||||
}
|
||||
dir := "/etc/controld"
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func tryReadingConfig(writeDefaultConfig bool) {
|
||||
configs := []struct {
|
||||
name string
|
||||
written bool
|
||||
}{
|
||||
// For compatibility, we check for config.toml first, but only read it if exists.
|
||||
{"config", false},
|
||||
{"ctrld", writeDefaultConfig},
|
||||
}
|
||||
|
||||
dir, err := userHomeDir()
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("failed to get config dir: %v", err)
|
||||
}
|
||||
for _, config := range configs {
|
||||
ctrld.SetConfigNameWithPath(v, config.name, dir)
|
||||
v.SetConfigFile(configPath)
|
||||
if readConfigFile(config.written) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
//go:build linux || freebsd
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func initRouterCLI() {
|
||||
validArgs := append(router.SupportedPlatforms(), "auto")
|
||||
var b strings.Builder
|
||||
b.WriteString("Auto-setup Control D on a router.\n\nSupported platforms:\n\n")
|
||||
for _, arg := range validArgs {
|
||||
b.WriteString(" ₒ ")
|
||||
b.WriteString(arg)
|
||||
if arg == "auto" {
|
||||
b.WriteString(" - detect the platform you are running on")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
routerCmd := &cobra.Command{
|
||||
Use: "setup",
|
||||
Short: b.String(),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if len(args) == 0 {
|
||||
_ = cmd.Help()
|
||||
return
|
||||
}
|
||||
if len(args) != 1 {
|
||||
_ = cmd.Help()
|
||||
return
|
||||
}
|
||||
platform := args[0]
|
||||
if platform == "auto" {
|
||||
platform = router.Name()
|
||||
}
|
||||
if !router.IsSupported(platform) {
|
||||
unsupportedPlatformHelp(cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("could not find executable path: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cmdArgs := []string{"start"}
|
||||
cmdArgs = append(cmdArgs, osArgs(platform)...)
|
||||
cmdArgs = append(cmdArgs, "--router")
|
||||
command := exec.Command(exe, cmdArgs...)
|
||||
command.Stdout = os.Stdout
|
||||
command.Stderr = os.Stderr
|
||||
command.Stdin = os.Stdin
|
||||
if err := command.Run(); err != nil {
|
||||
mainLog.Fatal().Msg(err.Error())
|
||||
}
|
||||
},
|
||||
}
|
||||
// Keep these flags in sync with startCmd, except for "--router".
|
||||
routerCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
routerCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
routerCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
routerCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
routerCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
routerCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
routerCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
routerCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
routerCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
routerCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = routerCmd.Flags().MarkHidden("dev")
|
||||
routerCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
|
||||
tmpl := routerCmd.UsageTemplate()
|
||||
tmpl = strings.Replace(tmpl, "{{.UseLine}}", "{{.UseLine}} [platform]", 1)
|
||||
routerCmd.SetUsageTemplate(tmpl)
|
||||
rootCmd.AddCommand(routerCmd)
|
||||
}
|
||||
|
||||
func osArgs(platform string) []string {
|
||||
args := os.Args[2:]
|
||||
n := 0
|
||||
for _, x := range args {
|
||||
if x != platform && x != "auto" {
|
||||
args[n] = x
|
||||
n++
|
||||
}
|
||||
}
|
||||
return args[:n]
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
//go:build !linux && !freebsd
|
||||
|
||||
package main
|
||||
|
||||
func initRouterCLI() {}
|
||||
@@ -1,23 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_writeConfigFile(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
// simulate --config CLI flag by setting configPath manually.
|
||||
configPath = filepath.Join(tmpdir, "ctrld.toml")
|
||||
_, err := os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
assert.NoError(t, writeConfigFile())
|
||||
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -1,519 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
const (
|
||||
staleTTL = 60 * time.Second
|
||||
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
|
||||
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
|
||||
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
|
||||
EDNS0_OPTION_MAC = 0xFDE9
|
||||
)
|
||||
|
||||
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "OS resolver",
|
||||
Type: ctrld.ResolverTypeOS,
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
func (p *prog) serveDNS(listenerNum string) error {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
// make sure ip is allocated
|
||||
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
|
||||
mainLog.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
|
||||
return allocErr
|
||||
}
|
||||
var failoverRcodes []int
|
||||
if listenerConfig.Policy != nil {
|
||||
failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers
|
||||
}
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
reqId := requestID()
|
||||
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), router.GetClientInfoByMac(macFromMsg(m)))
|
||||
fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String())
|
||||
t := time.Now()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
ctrld.Log(ctx, mainLog.Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
|
||||
upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, domain)
|
||||
var answer *dns.Msg
|
||||
if !matched && listenerConfig.Restricted {
|
||||
answer = new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
|
||||
} else {
|
||||
answer = p.proxy(ctx, upstreams, failoverRcodes, m)
|
||||
rtt := time.Since(t)
|
||||
ctrld.Log(ctx, mainLog.Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||
}
|
||||
if err := w.WriteMsg(answer); err != nil {
|
||||
ctrld.Log(ctx, mainLog.Error().Err(err), "serveUDP: failed to send DNS response to client")
|
||||
}
|
||||
})
|
||||
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
for _, proto := range []string{"udp", "tcp"} {
|
||||
proto := proto
|
||||
if needLocalIPv6Listener() {
|
||||
g.Go(func() error {
|
||||
s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), proto, handler)
|
||||
defer s.Shutdown()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case err := <-errCh:
|
||||
// Local ipv6 listener should not terminate ctrld.
|
||||
// It's a workaround for a quirk on Windows.
|
||||
mainLog.Warn().Err(err).Msg("local ipv6 listener failed")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
g.Go(func() error {
|
||||
s, errCh := runDNSServer(dnsListenAddress(listenerNum, listenerConfig), proto, handler)
|
||||
defer s.Shutdown()
|
||||
if listenerConfig.Port == 0 {
|
||||
switch s.Net {
|
||||
case "udp":
|
||||
mainLog.Info().Msgf("Random port chosen for udp listener.%s: %s", listenerNum, s.PacketConn.LocalAddr())
|
||||
case "tcp":
|
||||
mainLog.Info().Msgf("Random port chosen for tcp listener.%s: %s", listenerNum, s.Listener.Addr())
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// upstreamFor returns the list of upstreams for resolving the given domain,
|
||||
// matching by policies defined in the listener config. The second return value
|
||||
// reports whether the domain matches the policy.
|
||||
//
|
||||
// Though domain policy has higher priority than network policy, it is still
|
||||
// processed later, because policy logging want to know whether a network rule
|
||||
// is disregarded in favor of the domain level rule.
|
||||
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
|
||||
upstreams := []string{"upstream." + defaultUpstreamNum}
|
||||
matchedPolicy := "no policy"
|
||||
matchedNetwork := "no network"
|
||||
matchedRule := "no rule"
|
||||
matched := false
|
||||
|
||||
defer func() {
|
||||
if !matched && lc.Restricted {
|
||||
ctrld.Log(ctx, mainLog.Info(), "query refused, %s does not match any network policy", addr.String())
|
||||
return
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Info(), "%s, %s, %s -> %v", matchedPolicy, matchedNetwork, matchedRule, upstreams)
|
||||
}()
|
||||
|
||||
if lc.Policy == nil {
|
||||
return upstreams, false
|
||||
}
|
||||
|
||||
do := func(policyUpstreams []string) {
|
||||
upstreams = append([]string(nil), policyUpstreams...)
|
||||
}
|
||||
|
||||
var networkTargets []string
|
||||
var sourceIP net.IP
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
sourceIP = addr.IP
|
||||
case *net.TCPAddr:
|
||||
sourceIP = addr.IP
|
||||
}
|
||||
|
||||
networkRules:
|
||||
for _, rule := range lc.Policy.Networks {
|
||||
for source, targets := range rule {
|
||||
networkNum := strings.TrimPrefix(source, "network.")
|
||||
nc := p.cfg.Network[networkNum]
|
||||
if nc == nil {
|
||||
continue
|
||||
}
|
||||
for _, ipNet := range nc.IPNets {
|
||||
if ipNet.Contains(sourceIP) {
|
||||
matchedPolicy = lc.Policy.Name
|
||||
matchedNetwork = source
|
||||
networkTargets = targets
|
||||
matched = true
|
||||
break networkRules
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
// There's only one entry per rule, config validation ensures this.
|
||||
for source, targets := range rule {
|
||||
if source == domain || wildcardMatches(source, domain) {
|
||||
matchedPolicy = lc.Policy.Name
|
||||
if len(networkTargets) > 0 {
|
||||
matchedNetwork += " (unenforced)"
|
||||
}
|
||||
matchedRule = source
|
||||
do(targets)
|
||||
matched = true
|
||||
return upstreams, matched
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matched {
|
||||
do(networkTargets)
|
||||
}
|
||||
|
||||
return upstreams, matched
|
||||
}
|
||||
|
||||
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg) *dns.Msg {
|
||||
var staleAnswer *dns.Msg
|
||||
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{"upstream.os"}
|
||||
}
|
||||
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
|
||||
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
|
||||
for _, upstream := range upstreams {
|
||||
cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream))
|
||||
if cachedValue == nil {
|
||||
continue
|
||||
}
|
||||
answer := cachedValue.Msg.Copy()
|
||||
answer.SetRcode(msg, answer.Rcode)
|
||||
now := time.Now()
|
||||
if cachedValue.Expire.After(now) {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "hit cached response")
|
||||
setCachedAnswerTTL(answer, now, cachedValue.Expire)
|
||||
return answer
|
||||
}
|
||||
staleAnswer = answer
|
||||
}
|
||||
}
|
||||
resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
|
||||
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Error().Err(err), "failed to create resolver")
|
||||
return nil, err
|
||||
}
|
||||
resolveCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
if upstreamConfig.Timeout > 0 {
|
||||
timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(upstreamConfig.Timeout))
|
||||
defer cancel()
|
||||
resolveCtx = timeoutCtx
|
||||
}
|
||||
return dnsResolver.Resolve(resolveCtx, msg)
|
||||
}
|
||||
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
if upstreamConfig.UpstreamSendClientInfo() {
|
||||
ci := router.GetClientInfoByMac(macFromMsg(msg))
|
||||
if ci != nil {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "including client info with the request")
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci)
|
||||
}
|
||||
}
|
||||
answer, err := resolve1(n, upstreamConfig, msg)
|
||||
// Only do re-bootstrapping if bootstrap ip is not explicitly set by user.
|
||||
if err != nil && upstreamConfig.BootstrapIP == "" {
|
||||
ctrld.Log(ctx, mainLog.Debug().Err(err), "could not resolve query on first attempt, retrying...")
|
||||
// If any error occurred, re-bootstrap transport/ip, retry the request.
|
||||
upstreamConfig.ReBootstrap()
|
||||
answer, err = resolve1(n, upstreamConfig, msg)
|
||||
if err == nil {
|
||||
return answer
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Error().Err(err), "failed to resolve query")
|
||||
return nil
|
||||
}
|
||||
return answer
|
||||
}
|
||||
for n, upstreamConfig := range upstreamConfigs {
|
||||
if upstreamConfig == nil {
|
||||
continue
|
||||
}
|
||||
answer := resolve(n, upstreamConfig, msg)
|
||||
if answer == nil {
|
||||
if serveStaleCache && staleAnswer != nil {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "serving stale cached response")
|
||||
now := time.Now()
|
||||
setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL))
|
||||
return staleAnswer
|
||||
}
|
||||
continue
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "failover rcode matched, process to next upstream")
|
||||
continue
|
||||
}
|
||||
|
||||
// set compression, as it is not set by default when unpacking
|
||||
answer.Compress = true
|
||||
|
||||
if p.cache != nil {
|
||||
ttl := ttlFromMsg(answer)
|
||||
now := time.Now()
|
||||
expired := now.Add(time.Duration(ttl) * time.Second)
|
||||
if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 {
|
||||
expired = now.Add(time.Duration(cachedTTL) * time.Second)
|
||||
}
|
||||
setCachedAnswerTTL(answer, now, expired)
|
||||
p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired))
|
||||
ctrld.Log(ctx, mainLog.Debug(), "add cached response")
|
||||
}
|
||||
return answer
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Error(), "all upstreams failed")
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(msg, dns.RcodeServerFailure)
|
||||
return answer
|
||||
}
|
||||
|
||||
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
|
||||
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
|
||||
for _, upstream := range upstreams {
|
||||
upstreamNum := strings.TrimPrefix(upstream, "upstream.")
|
||||
upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
|
||||
}
|
||||
return upstreamConfigs
|
||||
}
|
||||
|
||||
// canonicalName returns canonical name from FQDN with "." trimmed.
|
||||
func canonicalName(fqdn string) string {
|
||||
q := strings.TrimSpace(fqdn)
|
||||
q = strings.TrimSuffix(q, ".")
|
||||
// https://datatracker.ietf.org/doc/html/rfc4343
|
||||
q = strings.ToLower(q)
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
func wildcardMatches(wildcard, domain string) bool {
|
||||
// Wildcard match.
|
||||
wildCardParts := strings.Split(wildcard, "*")
|
||||
if len(wildCardParts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(wildCardParts[0]) > 0 && len(wildCardParts[1]) > 0:
|
||||
// Domain must match both prefix and suffix.
|
||||
return strings.HasPrefix(domain, wildCardParts[0]) && strings.HasSuffix(domain, wildCardParts[1])
|
||||
|
||||
case len(wildCardParts[1]) > 0:
|
||||
// Only suffix must match.
|
||||
return strings.HasSuffix(domain, wildCardParts[1])
|
||||
|
||||
case len(wildCardParts[0]) > 0:
|
||||
// Only prefix must match.
|
||||
return strings.HasPrefix(domain, wildCardParts[0])
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func fmtRemoteToLocal(listenerNum, remote, local string) string {
|
||||
return fmt.Sprintf("%s -> listener.%s: %s:", remote, listenerNum, local)
|
||||
}
|
||||
|
||||
func requestID() string {
|
||||
b := make([]byte, 3) // 6 chars
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func containRcode(rcodes []int, rcode int) bool {
|
||||
for i := range rcodes {
|
||||
if rcodes[i] == rcode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) {
|
||||
ttlSecs := expiredTime.Sub(now).Seconds()
|
||||
if ttlSecs < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ttl := uint32(ttlSecs)
|
||||
for _, rr := range answer.Answer {
|
||||
rr.Header().Ttl = ttl
|
||||
}
|
||||
for _, rr := range answer.Ns {
|
||||
rr.Header().Ttl = ttl
|
||||
}
|
||||
for _, rr := range answer.Extra {
|
||||
if rr.Header().Rrtype != dns.TypeOPT {
|
||||
rr.Header().Ttl = ttl
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ttlFromMsg(msg *dns.Msg) uint32 {
|
||||
for _, rr := range msg.Answer {
|
||||
return rr.Header().Ttl
|
||||
}
|
||||
for _, rr := range msg.Ns {
|
||||
return rr.Header().Ttl
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func needLocalIPv6Listener() bool {
|
||||
// On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can
|
||||
// listen on ::1, then spawn a listener for receiving DNS requests.
|
||||
return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows"
|
||||
}
|
||||
|
||||
func dnsListenAddress(lcNum string, lc *ctrld.ListenerConfig) string {
|
||||
if addr := router.ListenAddress(); setupRouter && addr != "" && lcNum == "0" {
|
||||
return addr
|
||||
}
|
||||
return net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
|
||||
}
|
||||
|
||||
func macFromMsg(msg *dns.Msg) string {
|
||||
if opt := msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
switch e := s.(type) {
|
||||
case *dns.EDNS0_LOCAL:
|
||||
if e.Code == EDNS0_OPTION_MAC {
|
||||
return net.HardwareAddr(e.Data).String()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
|
||||
if ci != nil && ci.IP != "" {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
udpAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(ci.IP),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
return udpAddr
|
||||
case *net.TCPAddr:
|
||||
udpAddr := &net.TCPAddr{
|
||||
IP: net.ParseIP(ci.IP),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
return udpAddr
|
||||
}
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// runDNSServer starts a DNS server for given address and network,
|
||||
// with the given handler. It ensures the server has started listening.
|
||||
// Any error will be reported to the caller via returned channel.
|
||||
//
|
||||
// It's the caller responsibility to call Shutdown to close the server.
|
||||
func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-chan error) {
|
||||
s := &dns.Server{
|
||||
Addr: addr,
|
||||
Net: network,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
s.NotifyStartedFunc = waitLock.Unlock
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
if err := s.ListenAndServe(); err != nil {
|
||||
waitLock.Unlock()
|
||||
mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
waitLock.Lock()
|
||||
return s, errCh
|
||||
}
|
||||
|
||||
// runDNSServerForNTPD starts a DNS server listening on router.ListenAddress(). It must only be called when ctrld
|
||||
// running on router, before router.PreRun() to serve DNS request for NTP synchronization. The caller must call
|
||||
// s.Shutdown() explicitly when NTP is synced successfully.
|
||||
func runDNSServerForNTPD(addr string) (*dns.Server, <-chan error) {
|
||||
if addr == "" {
|
||||
return &dns.Server{}, nil
|
||||
}
|
||||
dnsResolver := ctrld.NewBootstrapResolver()
|
||||
s := &dns.Server{
|
||||
Addr: addr,
|
||||
Net: "udp",
|
||||
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
mainLog.Debug().Msg("Serving query for ntpd")
|
||||
resolveCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
if osUpstreamConfig.Timeout > 0 {
|
||||
timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(osUpstreamConfig.Timeout))
|
||||
defer cancel()
|
||||
resolveCtx = timeoutCtx
|
||||
}
|
||||
answer, err := dnsResolver.Resolve(resolveCtx, m)
|
||||
if err != nil {
|
||||
mainLog.Error().Err(err).Msgf("could not resolve: %v", m)
|
||||
return
|
||||
}
|
||||
if err := w.WriteMsg(answer); err != nil {
|
||||
mainLog.Error().Err(err).Msg("runDNSServerForNTPD: failed to send DNS response")
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
s.NotifyStartedFunc = waitLock.Unlock
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
if err := s.ListenAndServe(); err != nil {
|
||||
waitLock.Unlock()
|
||||
mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
waitLock.Lock()
|
||||
return s, errCh
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
)
|
||||
|
||||
func Test_wildcardMatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wildcard string
|
||||
domain string
|
||||
match bool
|
||||
}{
|
||||
{"prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
|
||||
{"prefix", "*.windscribe.com", "anything.windscribe.com", true},
|
||||
{"prefix not match other domain", "*.windscribe.com", "example.com", false},
|
||||
{"prefix not match domain in name", "*.windscribe.com", "wwindscribe.com", false},
|
||||
{"suffix", "suffix.*", "suffix.windscribe.com", true},
|
||||
{"suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
||||
{"both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
||||
{"both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := wildcardMatches(tc.wildcard, tc.domain); got != tc.match {
|
||||
t.Errorf("unexpected result, wildcard: %s, domain: %s, want: %v, got: %v", tc.wildcard, tc.domain, tc.match, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_canonicalName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
canonical string
|
||||
}{
|
||||
{"fqdn to canonical", "windscribe.com.", "windscribe.com"},
|
||||
{"already canonical", "windscribe.com", "windscribe.com"},
|
||||
{"case insensitive", "Windscribe.Com.", "windscribe.com"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := canonicalName(tc.domain); got != tc.canonical {
|
||||
t.Errorf("unexpected result, want: %s, got: %s", tc.canonical, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
defaultUpstreamNum string
|
||||
lc *ctrld.ListenerConfig
|
||||
domain string
|
||||
upstreams []string
|
||||
matched bool
|
||||
testLogMsg string
|
||||
}{
|
||||
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, network := range []string{"udp", "tcp"} {
|
||||
var (
|
||||
addr net.Addr
|
||||
err error
|
||||
)
|
||||
switch network {
|
||||
case "udp":
|
||||
addr, err = net.ResolveUDPAddr(network, tc.ip)
|
||||
case "tcp":
|
||||
addr, err = net.ResolveTCPAddr(network, tc.ip)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
|
||||
assert.Equal(t, tc.matched, matched)
|
||||
assert.Equal(t, tc.upstreams, upstreams)
|
||||
if tc.testLogMsg != "" {
|
||||
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
cacher, err := dnscache.NewLRUCache(4096)
|
||||
require.NoError(t, err)
|
||||
prog.cache = cacher
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com", dns.TypeA)
|
||||
msg.MsgHdr.RecursionDesired = true
|
||||
answer1 := new(dns.Msg)
|
||||
answer1.SetRcode(msg, dns.RcodeSuccess)
|
||||
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute)))
|
||||
answer2 := new(dns.Msg)
|
||||
answer2.SetRcode(msg, dns.RcodeRefused)
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
||||
|
||||
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg)
|
||||
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg)
|
||||
assert.NotSame(t, got1, got2)
|
||||
assert.Equal(t, answer1.Rcode, got1.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.Rcode)
|
||||
}
|
||||
|
||||
func Test_macFromMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mac string
|
||||
wantMac bool
|
||||
}{
|
||||
{"has mac", "4c:20:b8:ab:87:1b", true},
|
||||
{"no mac", "4c:20:b8:ab:87:1b", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hw, err := net.ParseMAC(tc.mac)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
if tc.wantMac {
|
||||
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
||||
o.Option = append(o.Option, ec1)
|
||||
}
|
||||
m.Extra = append(m.Extra, o)
|
||||
got := macFromMsg(m)
|
||||
if tc.wantMac && got != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, got)
|
||||
}
|
||||
if !tc.wantMac && got != "" {
|
||||
t.Errorf("unexpected mac: %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_remoteAddrFromMsg(t *testing.T) {
|
||||
loopbackIP := net.ParseIP("127.0.0.1")
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
ci *ctrld.ClientInfo
|
||||
want string
|
||||
}{
|
||||
{"tcp", &net.TCPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.10"}, "192.168.1.10:12345"},
|
||||
{"udp", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.11"}, "192.168.1.11:12345"},
|
||||
{"nil client info", &net.UDPAddr{IP: loopbackIP, Port: 12345}, nil, "127.0.0.1:12345"},
|
||||
{"empty ip", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{}, "127.0.0.1:12345"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
addr := spoofRemoteAddr(tc.addr, tc.ci)
|
||||
if addr.String() != tc.want {
|
||||
t.Errorf("unexpected result, want: %q, got: %q", tc.want, addr.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,135 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
var (
|
||||
configPath string
|
||||
configBase64 string
|
||||
daemon bool
|
||||
listenAddress string
|
||||
primaryUpstream string
|
||||
secondaryUpstream string
|
||||
domains []string
|
||||
logPath string
|
||||
homedir string
|
||||
cacheSize int
|
||||
cfg ctrld.Config
|
||||
verbose int
|
||||
silent bool
|
||||
cdUID string
|
||||
cdDev bool
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
setupRouter bool
|
||||
|
||||
mainLog = zerolog.New(io.Discard)
|
||||
consoleWriter zerolog.ConsoleWriter
|
||||
"github.com/Control-D-Inc/ctrld/cmd/cli"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
initCLI()
|
||||
initRouterCLI()
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeLogFilePath(logFilePath string) string {
|
||||
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
||||
return logFilePath
|
||||
}
|
||||
if homedir != "" {
|
||||
return filepath.Join(homedir, logFilePath)
|
||||
}
|
||||
dir, _ := userHomeDir()
|
||||
if dir == "" {
|
||||
return logFilePath
|
||||
}
|
||||
return filepath.Join(dir, logFilePath)
|
||||
}
|
||||
|
||||
func initConsoleLogging() {
|
||||
consoleWriter = zerolog.NewConsoleWriter(func(w *zerolog.ConsoleWriter) {
|
||||
w.TimeFormat = time.StampMilli
|
||||
})
|
||||
multi := zerolog.MultiLevelWriter(consoleWriter)
|
||||
mainLog = mainLog.Output(multi).With().Timestamp().Logger()
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
case verbose == 1:
|
||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||
case verbose > 1:
|
||||
zerolog.SetGlobalLevel(zerolog.DebugLevel)
|
||||
default:
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
}
|
||||
}
|
||||
|
||||
func initLogging() {
|
||||
writers := []io.Writer{io.Discard}
|
||||
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
|
||||
// Create parent directory if necessary.
|
||||
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
|
||||
mainLog.Error().Msgf("failed to create log path: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Backup old log file with .1 suffix.
|
||||
if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) {
|
||||
mainLog.Error().Msgf("could not backup old log file: %v", err)
|
||||
}
|
||||
logFile, err := os.OpenFile(logFilePath, os.O_CREATE|os.O_RDWR, os.FileMode(0o600))
|
||||
if err != nil {
|
||||
mainLog.Error().Msgf("failed to create log file: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
writers = append(writers, logFile)
|
||||
}
|
||||
writers = append(writers, consoleWriter)
|
||||
multi := zerolog.MultiLevelWriter(writers...)
|
||||
mainLog = mainLog.Output(multi).With().Timestamp().Logger()
|
||||
// TODO: find a better way.
|
||||
ctrld.ProxyLog = mainLog
|
||||
|
||||
zerolog.SetGlobalLevel(zerolog.NoticeLevel)
|
||||
logLevel := cfg.Service.LogLevel
|
||||
switch {
|
||||
case silent:
|
||||
zerolog.SetGlobalLevel(zerolog.NoLevel)
|
||||
return
|
||||
case verbose == 1:
|
||||
logLevel = "info"
|
||||
case verbose > 1:
|
||||
logLevel = "debug"
|
||||
}
|
||||
if logLevel == "" {
|
||||
return
|
||||
}
|
||||
level, err := zerolog.ParseLevel(logLevel)
|
||||
if err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not set log level")
|
||||
return
|
||||
}
|
||||
zerolog.SetGlobalLevel(level)
|
||||
}
|
||||
|
||||
func initCache() {
|
||||
if !cfg.Service.CacheEnable {
|
||||
return
|
||||
}
|
||||
if cfg.Service.CacheSize == 0 {
|
||||
cfg.Service.CacheSize = 4096
|
||||
}
|
||||
cli.Main()
|
||||
// make sure we exit with 0 if there are no errors
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
iface.Name = name
|
||||
mainLog.Debug().Str("network_service", name).Msg("found network service name for interface")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
scanner := bufio.NewScanner(r)
|
||||
prevLine := ""
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if strings.Contains(line, "*") {
|
||||
// Network services is disabled.
|
||||
continue
|
||||
}
|
||||
if !strings.Contains(line, "Device: "+ifaceName) {
|
||||
prevLine = line
|
||||
continue
|
||||
}
|
||||
parts := strings.SplitN(prevLine, " ", 2)
|
||||
if len(parts) == 2 {
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !darwin
|
||||
|
||||
package main
|
||||
|
||||
import "net"
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error { return nil }
|
||||
@@ -1,27 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState() {
|
||||
ch := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
if err := netlink.LinkSubscribe(ch, done); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not subscribe link")
|
||||
return
|
||||
}
|
||||
for lu := range ch {
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
mainLog.Debug().Msgf("link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package main
|
||||
|
||||
func (p *prog) watchLinkState() {}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user