mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
Compare commits
653 Commits
release-br
...
release-br
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e53fa4274 | ||
|
|
d0e66b83d0 | ||
|
|
34fef77ff7 | ||
|
|
7006e967e4 | ||
|
|
f9d026334a | ||
|
|
36d4192c05 | ||
|
|
90eddb8268 | ||
|
|
c13a3c3c17 | ||
|
|
d42a78cba9 | ||
|
|
92f32ba16e | ||
|
|
4c838f6a5e | ||
|
|
adc0e1a51e | ||
|
|
3afdaef6e6 | ||
|
|
ef7432df55 | ||
|
|
fb807d7c37 | ||
|
|
f7c124d99d | ||
|
|
ed826f7a95 | ||
|
|
56f8113bb0 | ||
|
|
a04babbbc3 | ||
|
|
59b98245d3 | ||
|
|
f6be1ab1fb | ||
|
|
54f58cc2e5 | ||
|
|
eb8c5bc3fa | ||
|
|
3bcad10f92 | ||
|
|
d87a0a69c8 | ||
|
|
b7202f8469 | ||
|
|
a084c87370 | ||
|
|
5d87bd07ca | ||
|
|
3412d1f8b9 | ||
|
|
a72ff1e769 | ||
|
|
2c98b2c545 | ||
|
|
4792183c0d | ||
|
|
d88c860cac | ||
|
|
8b605da861 | ||
|
|
7cda5d7646 | ||
|
|
0cd873a88f | ||
|
|
1ff5d1f05a | ||
|
|
954395fa29 | ||
|
|
ea98a59aba | ||
|
|
6971d392b7 | ||
|
|
a2f8313668 | ||
|
|
af05cb2d94 | ||
|
|
5f0b9a24b9 | ||
|
|
37523fdc45 | ||
|
|
ca505f1140 | ||
|
|
a22f0579d5 | ||
|
|
9f656269ac | ||
|
|
13de41d854 | ||
|
|
42ea5f7fed | ||
|
|
af9386568f | ||
|
|
13b15e642d | ||
|
|
d4df2e7f72 | ||
|
|
5b8ed3a72f | ||
|
|
59fe94112d | ||
|
|
0ab51cdad7 | ||
|
|
0a1d6fa4db | ||
|
|
6e10bba7fe | ||
|
|
fc8268b70a | ||
|
|
b5f101f667 | ||
|
|
69b192c6fa | ||
|
|
ec85b1621d | ||
|
|
ddbb0f0db4 | ||
|
|
2996a161cd | ||
|
|
35e2a20019 | ||
|
|
65a300a807 | ||
|
|
a67aea88be | ||
|
|
84d4491a18 | ||
|
|
05d183c94b | ||
|
|
41282d0f51 | ||
|
|
f7fb555c89 | ||
|
|
2e63624f6c | ||
|
|
b2a54db4b5 | ||
|
|
0ef02bc15e | ||
|
|
b18cd7ee83 | ||
|
|
a16b25ad1d | ||
|
|
59ece456b1 | ||
|
|
d5cb327620 | ||
|
|
7a2277bc18 | ||
|
|
c736f4c1e9 | ||
|
|
f0cb810dd6 | ||
|
|
64632fa640 | ||
|
|
b9b9cfcade | ||
|
|
fc527dbdfb | ||
|
|
5641aab5bd | ||
|
|
31517ce750 | ||
|
|
51e58b64a5 | ||
|
|
3ca559e5a4 | ||
|
|
0e3f764299 | ||
|
|
e52402eb0c | ||
|
|
2133f31854 | ||
|
|
a198a5cd65 | ||
|
|
eb2b231bd2 | ||
|
|
7af29cfbc0 | ||
|
|
ce1a165348 | ||
|
|
fd48e6d795 | ||
|
|
d71d1341b6 | ||
|
|
21855df4af | ||
|
|
66e2d3a40a | ||
|
|
26257cf24a | ||
|
|
36a7423634 | ||
|
|
e616091249 | ||
|
|
0948161529 | ||
|
|
ce29b5d217 | ||
|
|
de24fa293e | ||
|
|
6663925c4d | ||
|
|
b9ece6d7b9 | ||
|
|
c4efa1ab97 | ||
|
|
7cea5305e1 | ||
|
|
a20fbf95de | ||
|
|
628c4302aa | ||
|
|
8dc34f8bf5 | ||
|
|
b4faf82f76 | ||
|
|
a983dfaee2 | ||
|
|
62f73bcaa2 | ||
|
|
00e9d2bdd3 | ||
|
|
ace3b1e66e | ||
|
|
d1ea1ba08c | ||
|
|
c06c8aa859 | ||
|
|
0c2cc00c4f | ||
|
|
8d6ea91f35 | ||
|
|
7dfb77228f | ||
|
|
24910f1fa6 | ||
|
|
433a61d2ee | ||
|
|
3937e885f0 | ||
|
|
c651003cc4 | ||
|
|
b7ccfcb8b4 | ||
|
|
a9ed70200b | ||
|
|
c6365e6b74 | ||
|
|
dacc67e50f | ||
|
|
c60cf33af3 | ||
|
|
f27cbe3525 | ||
|
|
2de1b9929a | ||
|
|
8bf654aece | ||
|
|
84376ed719 | ||
|
|
7a136b8874 | ||
|
|
58c0e4f15a | ||
|
|
e0d35d8ba2 | ||
|
|
3b2e48761e | ||
|
|
b27064008e | ||
|
|
1ad63827e1 | ||
|
|
20e61550c2 | ||
|
|
020b814402 | ||
|
|
e578867118 | ||
|
|
46a1039f21 | ||
|
|
cc9e27de5f | ||
|
|
6ab3ab9faf | ||
|
|
e68bfa795a | ||
|
|
e60a92e93e | ||
|
|
62fe14f76b | ||
|
|
a0c5062e3a | ||
|
|
49eb152d02 | ||
|
|
b05056423a | ||
|
|
c7168739c7 | ||
|
|
5b1faf1ce3 | ||
|
|
513a6f9ec7 | ||
|
|
8db6fa4232 | ||
|
|
5036de2602 | ||
|
|
332f8ccc37 | ||
|
|
a582195cec | ||
|
|
9fe36ae984 | ||
|
|
54cb455522 | ||
|
|
8bd3b9e474 | ||
|
|
eff5ff580b | ||
|
|
c45f863ed8 | ||
|
|
414d4e356d | ||
|
|
ef697eb781 | ||
|
|
0631ffe831 | ||
|
|
7444d8517a | ||
|
|
3480043e40 | ||
|
|
619b6e7516 | ||
|
|
0123ca44fb | ||
|
|
7929aafe2a | ||
|
|
dc433f8dc9 | ||
|
|
8ccaeeab60 | ||
|
|
043a28eb33 | ||
|
|
c329402f5d | ||
|
|
23e6ad6e1f | ||
|
|
e6de78c1fa | ||
|
|
a670708f93 | ||
|
|
4ebe2fb5f4 | ||
|
|
3403b2039d | ||
|
|
e30ad31e0f | ||
|
|
81e0bad739 | ||
|
|
7d07d738dc | ||
|
|
0fae584e65 | ||
|
|
9e83085f2a | ||
|
|
41a00c68ac | ||
|
|
e3b99bf339 | ||
|
|
5007a87d3a | ||
|
|
60e65a37a6 | ||
|
|
d37d0e942c | ||
|
|
98042d8dbd | ||
|
|
af4b826b68 | ||
|
|
253a57ca01 | ||
|
|
caf98b4dfe | ||
|
|
398f71fd00 | ||
|
|
e1301ade96 | ||
|
|
7a23f82192 | ||
|
|
715bcc4aa1 | ||
|
|
0c74838740 | ||
|
|
4b05b6da7b | ||
|
|
375844ff1a | ||
|
|
1d207379cb | ||
|
|
fb49cb71e3 | ||
|
|
9618efbcde | ||
|
|
bb2210b06a | ||
|
|
917052723d | ||
|
|
fef85cadeb | ||
|
|
4a05fb6b28 | ||
|
|
6644ce53f2 | ||
|
|
72f0b89fdc | ||
|
|
41a97a6609 | ||
|
|
38064d6ad5 | ||
|
|
ae6945cedf | ||
|
|
3132d1b032 | ||
|
|
2716ae29bd | ||
|
|
1c50c2b6af | ||
|
|
cf6d16b439 | ||
|
|
60686f55ff | ||
|
|
47d7ace3a7 | ||
|
|
2d3779ec27 | ||
|
|
595071b608 | ||
|
|
57ef717080 | ||
|
|
eb27d1482b | ||
|
|
f57972ead7 | ||
|
|
168eaf538b | ||
|
|
1560455ca3 | ||
|
|
028475a193 | ||
|
|
f7a6dbe39b | ||
|
|
e573a490c9 | ||
|
|
ce3281e70d | ||
|
|
0fbfd160c9 | ||
|
|
20759017e6 | ||
|
|
69e0aab73e | ||
|
|
7ed6733fb7 | ||
|
|
9718ab8579 | ||
|
|
2687a4a018 | ||
|
|
2d9c60dea1 | ||
|
|
841be069b7 | ||
|
|
7833132917 | ||
|
|
e9e63b0983 | ||
|
|
4df470b869 | ||
|
|
89600f6091 | ||
|
|
f986a575e8 | ||
|
|
9c2fe8d21f | ||
|
|
8bcbb9249e | ||
|
|
a95d50c0af | ||
|
|
5db7d3577b | ||
|
|
c53a0ca1c4 | ||
|
|
6fd3d1788a | ||
|
|
087c1975e5 | ||
|
|
3713cbecc3 | ||
|
|
6046789fa4 | ||
|
|
3ea69b180c | ||
|
|
db6e977e3a | ||
|
|
a5c776c846 | ||
|
|
5a566c028a | ||
|
|
ff43c74d8d | ||
|
|
3c7255569c | ||
|
|
4a92ec4d2d | ||
|
|
9bbccb4082 | ||
|
|
4f62314646 | ||
|
|
cb49d0d947 | ||
|
|
89f7874fc6 | ||
|
|
221917e80b | ||
|
|
37d41bd215 | ||
|
|
8a96b8bec4 | ||
|
|
02ee113b95 | ||
|
|
f71dd78915 | ||
|
|
cd5619a05b | ||
|
|
a63a30c76b | ||
|
|
f5ba8be182 | ||
|
|
a9f76322bd | ||
|
|
ed39269c80 | ||
|
|
09426dcd36 | ||
|
|
17941882a9 | ||
|
|
70ab8032a0 | ||
|
|
8360bdc50a | ||
|
|
6837176ec7 | ||
|
|
5e9b4244e7 | ||
|
|
9b6a308958 | ||
|
|
71e327653a | ||
|
|
a56711796f | ||
|
|
09495f2a7c | ||
|
|
484643e114 | ||
|
|
da91aabc35 | ||
|
|
c654398981 | ||
|
|
47a90ec2a1 | ||
|
|
2875e22d0b | ||
|
|
c5d14e0075 | ||
|
|
84e06c363c | ||
|
|
5b9ccc5065 | ||
|
|
6ca1a7ccc7 | ||
|
|
9d666be5d4 | ||
|
|
65de7edcde | ||
|
|
0cdff0d368 | ||
|
|
f87220a908 | ||
|
|
30ea0c6499 | ||
|
|
9501e35c60 | ||
|
|
5ac9d17bdf | ||
|
|
cb14992ddc | ||
|
|
e88372fc8c | ||
|
|
b320662d67 | ||
|
|
ce353cd4d9 | ||
|
|
4befd33866 | ||
|
|
4b36e3ac44 | ||
|
|
f507bc8f9e | ||
|
|
14c88f4a6d | ||
|
|
3e388c2857 | ||
|
|
cfe1209d61 | ||
|
|
5a88a7c22c | ||
|
|
8c661c4401 | ||
|
|
e6f256d640 | ||
|
|
ede354166b | ||
|
|
282a8ce78e | ||
|
|
08fe04f1ee | ||
|
|
082d14a9ba | ||
|
|
617674ce43 | ||
|
|
7088df58dd | ||
|
|
9cbd9b3e44 | ||
|
|
e6586fd360 | ||
|
|
33a6db2599 | ||
|
|
70b0c4f7b9 | ||
|
|
5af3ec4f7b | ||
|
|
79476add12 | ||
|
|
1634a06330 | ||
|
|
a007394f60 | ||
|
|
62a0ba8731 | ||
|
|
e8d3ed1acd | ||
|
|
8b98faa441 | ||
|
|
30320ec9c7 | ||
|
|
5f4a399850 | ||
|
|
82e0d4b0c4 | ||
|
|
95a9df826d | ||
|
|
3b71d26cf3 | ||
|
|
c233ad9b1b | ||
|
|
12d6484b1c | ||
|
|
bc7b1cc6d8 | ||
|
|
ec684348ed | ||
|
|
18a19a3aa2 | ||
|
|
905f2d08c5 | ||
|
|
04947b4d87 | ||
|
|
72bf80533e | ||
|
|
9ddedf926e | ||
|
|
139dd62ff3 | ||
|
|
50ef00526e | ||
|
|
80cf79b9cb | ||
|
|
e6ad39b070 | ||
|
|
56f9c72569 | ||
|
|
dc48c908b8 | ||
|
|
9b0f0e792a | ||
|
|
b3eebb19b6 | ||
|
|
c24589a5be | ||
|
|
1e1c5a4dc8 | ||
|
|
339023421a | ||
|
|
a00d2a431a | ||
|
|
5aca118dbb | ||
|
|
411f7434f4 | ||
|
|
34801382f5 | ||
|
|
b9f2259ae4 | ||
|
|
19020a96bf | ||
|
|
96085147ff | ||
|
|
f3dd344026 | ||
|
|
486096416f | ||
|
|
5710f2e984 | ||
|
|
09936f1f07 | ||
|
|
0d6ca57536 | ||
|
|
3ddcb84db8 | ||
|
|
1012bf063f | ||
|
|
b8155e6182 | ||
|
|
9a34df61bb | ||
|
|
fbb879edf9 | ||
|
|
ac97c88876 | ||
|
|
a1fda2c0de | ||
|
|
f499770d45 | ||
|
|
4769da4ef4 | ||
|
|
c2556a8e39 | ||
|
|
29bf329f6a | ||
|
|
1dee4305bc | ||
|
|
429a98b690 | ||
|
|
da01a146d2 | ||
|
|
dd9f2465be | ||
|
|
b5cf0e2b31 | ||
|
|
1db159ad34 | ||
|
|
6604f973ac | ||
|
|
69ee6582e2 | ||
|
|
6f12667e8c | ||
|
|
b002dff624 | ||
|
|
affef963c1 | ||
|
|
56b2056190 | ||
|
|
c1e6f5126a | ||
|
|
1a8c1ec73d | ||
|
|
52954b8ceb | ||
|
|
a5025e35ea | ||
|
|
07f80c9ebf | ||
|
|
13db23553d | ||
|
|
3963fce43b | ||
|
|
ea4e5147bd | ||
|
|
7a491a4cc5 | ||
|
|
5ba90748f6 | ||
|
|
20f8f22bae | ||
|
|
b50cccac85 | ||
|
|
34ebe9b054 | ||
|
|
43d82cf1a7 | ||
|
|
ab88174091 | ||
|
|
ebcbf85373 | ||
|
|
87513cba6d | ||
|
|
64bcd2f00d | ||
|
|
cc6ae290f8 | ||
|
|
3e62bd3dbd | ||
|
|
8491f9c455 | ||
|
|
3ca754b438 | ||
|
|
8c7c3901e8 | ||
|
|
a9672dfff5 | ||
|
|
203a2ec8b8 | ||
|
|
810cbd1f4f | ||
|
|
49eebcdcbc | ||
|
|
e89021ec3a | ||
|
|
73a697b2fa | ||
|
|
9319d08046 | ||
|
|
7dc5138e91 | ||
|
|
8f189c919a | ||
|
|
906479a15c | ||
|
|
dabbf2037b | ||
|
|
b496147ce7 | ||
|
|
583718f234 | ||
|
|
fdb82f6ec3 | ||
|
|
5145729ab1 | ||
|
|
4d810261a4 | ||
|
|
18e8616834 | ||
|
|
d55563cac5 | ||
|
|
bb481d9bcc | ||
|
|
a163be3584 | ||
|
|
891b7cb2c6 | ||
|
|
176c22f229 | ||
|
|
faa0ed06b6 | ||
|
|
9515db7faf | ||
|
|
d822bf4257 | ||
|
|
0826671809 | ||
|
|
67d74774a9 | ||
|
|
5d65416227 | ||
|
|
49441f62f3 | ||
|
|
99651f6e5b | ||
|
|
edca1f4f89 | ||
|
|
3d834f00f6 | ||
|
|
6bb9e7a766 | ||
|
|
61fb71b1fa | ||
|
|
f8967c376f | ||
|
|
6d3c86c0be | ||
|
|
e42554f892 | ||
|
|
28984090e5 | ||
|
|
251255c746 | ||
|
|
32709dc64c | ||
|
|
71f26a6d81 | ||
|
|
44352f8006 | ||
|
|
af38623590 | ||
|
|
9c1665a759 | ||
|
|
eaad24e5e5 | ||
|
|
cfaf32f71a | ||
|
|
51b235b61a | ||
|
|
0a6d9d4454 | ||
|
|
dc700bbd52 | ||
|
|
cb445825f4 | ||
|
|
4d996e317b | ||
|
|
30c9012004 | ||
|
|
2a23feaf4b | ||
|
|
b82ad3720c | ||
|
|
8d2cb6091e | ||
|
|
3023f33dff | ||
|
|
22e97e981a | ||
|
|
44484e1231 | ||
|
|
eac60b87c7 | ||
|
|
8db28cb76e | ||
|
|
8dbe828b99 | ||
|
|
5c24acd952 | ||
|
|
998b9a5c5d | ||
|
|
0084e9ef26 | ||
|
|
122600bff2 | ||
|
|
41846b6d4c | ||
|
|
dfbcb1489d | ||
|
|
684019c2e3 | ||
|
|
e92619620d | ||
|
|
cebfd12d5c | ||
|
|
874ff01ab8 | ||
|
|
0bb8703f78 | ||
|
|
0bb51aa71d | ||
|
|
af2c1c87e0 | ||
|
|
8939debbc0 | ||
|
|
7591a0ccc6 | ||
|
|
c3ff8182af | ||
|
|
5897c174d3 | ||
|
|
f9a3f4c045 | ||
|
|
a2cb895cdc | ||
|
|
2bebe93e47 | ||
|
|
28ec1869fc | ||
|
|
17f6d7a77b | ||
|
|
9e6e647ff8 | ||
|
|
a2116e5eb5 | ||
|
|
564c9ef712 | ||
|
|
856abb71b7 | ||
|
|
0a30fdea69 | ||
|
|
4f125cf107 | ||
|
|
494d8be777 | ||
|
|
cd9c750884 | ||
|
|
91d319804b | ||
|
|
180eae60f2 | ||
|
|
d01f5c2777 | ||
|
|
294a90a807 | ||
|
|
c3b4ae9c79 | ||
|
|
09188bedf7 | ||
|
|
4614b98e94 | ||
|
|
990bc620f7 | ||
|
|
efb5a92571 | ||
|
|
8e0a96a44c | ||
|
|
43ff2f648c | ||
|
|
4816a09e3a | ||
|
|
3fea92c8b1 | ||
|
|
63f959c951 | ||
|
|
44ba6aadd9 | ||
|
|
d88cf52b4e | ||
|
|
58a00ea24a | ||
|
|
712b23a4bb | ||
|
|
baf836557c | ||
|
|
904b23eeac | ||
|
|
6aafe445f5 | ||
|
|
ebd516855b | ||
|
|
df4e04719e | ||
|
|
2440d922c6 | ||
|
|
f1b8d1c4ad | ||
|
|
79076bda35 | ||
|
|
9d2ea15346 | ||
|
|
77c1113ff7 | ||
|
|
e03ad4cd77 | ||
|
|
6e28517454 | ||
|
|
8ddbf881b3 | ||
|
|
c58516cfb0 | ||
|
|
34758f6205 | ||
|
|
a9959a6f3d | ||
|
|
511c4e696f | ||
|
|
bed7435b0c | ||
|
|
507c1afd59 | ||
|
|
2765487f10 | ||
|
|
80a88811cd | ||
|
|
823195c504 | ||
|
|
0f3e8c7ada | ||
|
|
ee5eb4fc4e | ||
|
|
d58d8074f4 | ||
|
|
94a0530991 | ||
|
|
073af0f89c | ||
|
|
6028b8f186 | ||
|
|
126477ef88 | ||
|
|
13391fd469 | ||
|
|
82e44b01af | ||
|
|
e355fd70ab | ||
|
|
d5c171735e | ||
|
|
b175368794 | ||
|
|
bcf4c25ba8 | ||
|
|
11b09af76d | ||
|
|
af0380a96a | ||
|
|
f39512b4c0 | ||
|
|
7ce62ccaec | ||
|
|
44c0a06996 | ||
|
|
f7d3db06c6 | ||
|
|
0ca37dc707 | ||
|
|
2bcba7b578 | ||
|
|
829e93c079 | ||
|
|
4896563e3c | ||
|
|
0c096d5f07 | ||
|
|
ab8f072388 | ||
|
|
32219e7d32 | ||
|
|
d292e03d1b | ||
|
|
5dd6336953 | ||
|
|
854a244ebb | ||
|
|
125b4b6077 | ||
|
|
46e8d4fad7 | ||
|
|
e5389ffecb | ||
|
|
46509be8a0 | ||
|
|
d3d2ed539f | ||
|
|
8496adc638 | ||
|
|
e1d078a2c3 | ||
|
|
0dee7518c4 | ||
|
|
774f07dd7f | ||
|
|
c271896551 | ||
|
|
82d887f52d | ||
|
|
6e27f877ff | ||
|
|
39a2cab051 | ||
|
|
72d2f4e7e3 | ||
|
|
19bc44a7f3 | ||
|
|
59dc74ffbb | ||
|
|
12c8ab696f | ||
|
|
28f32bd7e5 | ||
|
|
6b43639be5 | ||
|
|
6be80e4827 | ||
|
|
437fb1b16d | ||
|
|
61b6431b6e | ||
|
|
7ccecdd9f7 | ||
|
|
e43b2b5530 | ||
|
|
2cd8b7e021 | ||
|
|
d6768c4c39 | ||
|
|
59a895bfe2 | ||
|
|
cacd957594 | ||
|
|
2cd063ebd6 | ||
|
|
9ed8e49a08 | ||
|
|
66cb7cc21d | ||
|
|
4bf09120ff | ||
|
|
be0769e433 | ||
|
|
7b476e38be | ||
|
|
0a7d3445f4 | ||
|
|
76d2e2c226 | ||
|
|
3007cb86ec | ||
|
|
fa3af372ab | ||
|
|
48a780fc3e | ||
|
|
28df551195 | ||
|
|
e65a71b2ae | ||
|
|
dc61fd2554 | ||
|
|
a4edf266f0 | ||
|
|
7af59ee589 | ||
|
|
3f3c1d6d78 | ||
|
|
ab1d7fd796 | ||
|
|
6c2996a921 | ||
|
|
de32dd8ba4 | ||
|
|
d43e50ee2d | ||
|
|
aec2596262 | ||
|
|
78a7c87ecc | ||
|
|
1d3f8757bc | ||
|
|
c0c69d0739 | ||
|
|
1aa991298a | ||
|
|
f3a3227f21 | ||
|
|
a4c1983657 | ||
|
|
cc28b92935 | ||
|
|
eaa907a647 | ||
|
|
de951fd895 | ||
|
|
3f211d3cc2 | ||
|
|
2f46d512c6 | ||
|
|
12148ec231 | ||
|
|
9fe6af684f | ||
|
|
472bb05e95 | ||
|
|
50bfed706d | ||
|
|
350d8355b1 | ||
|
|
03781d4cec | ||
|
|
67e4afc06e | ||
|
|
32482809b7 | ||
|
|
c315d21be9 | ||
|
|
48b2031269 | ||
|
|
41139b3343 | ||
|
|
d5e6c7b13f | ||
|
|
60d6734e1f | ||
|
|
e684c7d8c4 | ||
|
|
ce35383341 | ||
|
|
5553490b27 | ||
|
|
eaf39f48a0 | ||
|
|
a5ddbdcb42 | ||
|
|
0c99d27be5 |
2
.dockerignore
Normal file
2
.dockerignore
Normal file
@@ -0,0 +1,2 @@
|
||||
Dockerfile
|
||||
.git/
|
||||
8
.github/workflows/ci.yml
vendored
8
.github/workflows/ci.yml
vendored
@@ -9,18 +9,18 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ["windows-latest", "ubuntu-latest", "macOS-latest"]
|
||||
go: ["1.20.x"]
|
||||
go: ["1.24.x"]
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 1
|
||||
- uses: WillAbides/setup-go-faster@v1.8.0
|
||||
- uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go }}
|
||||
- run: "go test -race ./..."
|
||||
- uses: dominikh/staticcheck-action@v1.2.0
|
||||
- uses: dominikh/staticcheck-action@v1.4.0
|
||||
with:
|
||||
version: "2023.1.2"
|
||||
version: "2025.1.1"
|
||||
install-go: false
|
||||
cache-key: ${{ matrix.go }}
|
||||
|
||||
9
.gitignore
vendored
9
.gitignore
vendored
@@ -3,3 +3,12 @@ gon.hcl
|
||||
|
||||
/Build
|
||||
.DS_Store
|
||||
|
||||
# Release folder
|
||||
dist/
|
||||
|
||||
# Binaries
|
||||
ctrld-*
|
||||
|
||||
# generated file
|
||||
cmd/cli/rsrc_*.syso
|
||||
|
||||
232
README.md
232
README.md
@@ -4,11 +4,15 @@
|
||||
[](https://pkg.go.dev/github.com/Control-D-Inc/ctrld)
|
||||
[](https://goreportcard.com/report/github.com/Control-D-Inc/ctrld)
|
||||
|
||||

|
||||
|
||||
A highly configurable DNS forwarding proxy with support for:
|
||||
- Multiple listeners for incoming queries
|
||||
- Multiple upstreams with fallbacks
|
||||
- Multiple network policy driven DNS query steering
|
||||
- Multiple network policy driven DNS query steering (via network cidr, MAC address or FQDN)
|
||||
- Policy driven domain based "split horizon" DNS with wildcard support
|
||||
- LAN client discovery via DHCP, mDNS, ARP, NDP, hosts file parsing
|
||||
- Prometheus metrics exporter
|
||||
|
||||
## TLDR
|
||||
Proxy legacy DNS traffic to secure DNS upstreams in highly configurable ways.
|
||||
@@ -21,47 +25,58 @@ All DNS protocols are supported, including:
|
||||
- `DNS-over-QUIC`
|
||||
|
||||
# Use Cases
|
||||
1. Use secure DNS protocols on networks and devices that don't natively support them (legacy routers, legacy OSes, TVs, smart toasters).
|
||||
1. Use secure DNS protocols on networks and devices that don't natively support them (legacy OSes, TVs, smart toasters).
|
||||
2. Create source IP based DNS routing policies with variable secure DNS upstreams. Subnet 1 (admin) uses upstream resolver A, while Subnet 2 (employee) uses upstream resolver B.
|
||||
3. Create destination IP based DNS routing policies with variable secure DNS upstreams. Listener 1 uses upstream resolver C, while Listener 2 uses upstream resolver D.
|
||||
4. Create domain level "split horizon" DNS routing policies to send internal domains (*.company.int) to a local DNS server, while everything else goes to another upstream.
|
||||
|
||||
|
||||
## OS Support
|
||||
- Windows (386, amd64, arm)
|
||||
- Mac (amd64, arm64)
|
||||
- Windows Desktop (386, amd64, arm)
|
||||
- MacOS (amd64, arm64)
|
||||
- Linux (386, amd64, arm, mips)
|
||||
- Common routers (See Router Mode below)
|
||||
- FreeBSD (386, amd64, arm)
|
||||
|
||||
# Install
|
||||
There are several ways to download and install `ctrld.
|
||||
There are several ways to download and install `ctrld`.
|
||||
|
||||
## Quick Install
|
||||
The simplest way to download and install `ctrld` is to use the following installer command on any UNIX-like platform:
|
||||
|
||||
```shell
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl)"'
|
||||
sh -c 'sh -c "$(curl -sL https://api.controld.com/dl?version=2)"'
|
||||
```
|
||||
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative cmd:
|
||||
Windows user and prefer Powershell (who doesn't)? No problem, execute this command instead in administrative PowerShell:
|
||||
```shell
|
||||
powershell -Command "(Invoke-WebRequest -Uri 'https://api.controld.com/dl' -UseBasicParsing).Content | Set-Content 'ctrld_install.bat'" && ctrld_install.bat
|
||||
(Invoke-WebRequest -Uri 'https://api.controld.com/dl/ps1?version=2' -UseBasicParsing).Content | Set-Content "$env:TEMPctrld_install.ps1"; Invoke-Expression "& '$env:TEMPctrld_install.ps1'"
|
||||
```
|
||||
|
||||
Or you can pull and run a Docker container from [Docker Hub](https://hub.docker.com/r/controldns/ctrld)
|
||||
```shell
|
||||
docker run -d --name=ctrld -p 127.0.0.1:53:53/tcp -p 127.0.0.1:53:53/udp controldns/ctrld:latest
|
||||
```
|
||||
|
||||
## Download Manually
|
||||
Alternatively, if you know what you're doing you can download pre-compiled binaries from the [Releases](https://github.com/Control-D-Inc/ctrld/releases) section for the appropriate platform.
|
||||
|
||||
## Build
|
||||
Lastly, you can build `ctrld` from source which requires `go1.19+`:
|
||||
Lastly, you can build `ctrld` from source which requires `go1.23+`:
|
||||
|
||||
```shell
|
||||
$ go build ./cmd/ctrld
|
||||
go build ./cmd/ctrld
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
$ go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
go install github.com/Control-D-Inc/ctrld/cmd/ctrld@latest
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
docker build -t controldns/ctrld . -f docker/Dockerfile
|
||||
```
|
||||
|
||||
|
||||
@@ -82,129 +97,118 @@ Usage:
|
||||
|
||||
Available Commands:
|
||||
run Run the DNS proxy server
|
||||
service Manage ctrld service
|
||||
start Quick start service and configure DNS on interface
|
||||
stop Quick stop service and remove DNS from interface
|
||||
setup Auto-setup Control D on a router.
|
||||
|
||||
Supported platforms:
|
||||
|
||||
ₒ ddwrt
|
||||
ₒ merlin
|
||||
ₒ openwrt
|
||||
ₒ ubios
|
||||
ₒ auto - detect the platform you are running on
|
||||
|
||||
restart Restart the ctrld service
|
||||
reload Reload the ctrld service
|
||||
status Show status of the ctrld service
|
||||
uninstall Stop and uninstall the ctrld service
|
||||
service Manage ctrld service
|
||||
clients Manage clients
|
||||
upgrade Upgrading ctrld to latest version
|
||||
log Manage runtime debug logs
|
||||
|
||||
Flags:
|
||||
-h, --help help for ctrld
|
||||
-s, --silent do not write any log output
|
||||
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug level logging
|
||||
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug logging
|
||||
--version version for ctrld
|
||||
|
||||
Use "ctrld [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
## Basic Run Mode
|
||||
To start the server with default configuration, simply run: `./ctrld run`. This will create a generic `ctrld.toml` file in the **working directory** and start the application in foreground.
|
||||
1. Start the server
|
||||
```
|
||||
$ sudo ./ctrld run
|
||||
This is the most basic way to run `ctrld`, in foreground mode. Unless you already have a config file, a default one will be generated.
|
||||
|
||||
### Command
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe run
|
||||
```
|
||||
|
||||
2. Run a test query using a DNS client, for example, `dig`:
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld run
|
||||
```
|
||||
|
||||
You can then run a test query using a DNS client, for example, `dig`:
|
||||
```
|
||||
$ dig verify.controld.com @127.0.0.1 +short
|
||||
api.controld.com.
|
||||
147.185.34.1
|
||||
```
|
||||
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file and go nuts with it. To enforce a new config, restart the server.
|
||||
If `verify.controld.com` resolves, you're successfully using the default Control D upstream. From here, you can start editing the config file that was generated. To enforce a new config, restart the server.
|
||||
|
||||
## Service Mode
|
||||
To run the application in service mode on any Windows, MacOS or Linux distibution, simply run: `./ctrld start` as system/root user. This will create a generic `ctrld.toml` file in the **user home** directory (on Windows) or `/etc/controld/` (everywhere else), start the system service, and configure the listener on the default network interface. Service will start on OS boot.
|
||||
This mode will run the application as a background system service on any Windows, MacOS, Linux or FreeBSD distribution. This will create a generic `ctrld.toml` file in the **C:\ControlD** directory (on Windows) or `/etc/controld/` (almost everywhere else), start the system service, and **configure the listener on all physical network interface**. Service will start on OS boot.
|
||||
|
||||
In order to stop the service, and restore your DNS to original state, simply run `./ctrld stop`. If you wish to uninstall the service permanently, run `./ctrld service uninstall`.
|
||||
### Command
|
||||
|
||||
For granular control of the service, run the `service` command. Each sub-command has its own help section so you can see what arguments you can supply.
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start
|
||||
```
|
||||
|
||||
```
|
||||
Manage ctrld service
|
||||
Linux or Macos
|
||||
```
|
||||
sudo ctrld start
|
||||
```
|
||||
|
||||
Usage:
|
||||
ctrld service [command]
|
||||
If `ctrld` is not in your system path (you installed it manually), you will need to run the above commands from the directory where you installed `ctrld`.
|
||||
|
||||
Available Commands:
|
||||
interfaces Manage network interfaces
|
||||
restart Restart the ctrld service
|
||||
start Start the ctrld service
|
||||
status Show status of the ctrld service
|
||||
stop Stop the ctrld service
|
||||
uninstall Uninstall the ctrld service
|
||||
In order to stop the service, and restore your DNS to original state, simply run `ctrld stop`. If you wish to stop and uninstall the service permanently, run `ctrld uninstall`.
|
||||
|
||||
Flags:
|
||||
-h, --help help for service
|
||||
## Unmanaged Service Mode
|
||||
This mode functions similarly to the "Service Mode" above except it will simply start a system service and the config defined listeners, but **will not make any changes to any network interfaces**. You can then set the `ctrld` listener(s) IP on the desired network interfaces manually.
|
||||
|
||||
Global Flags:
|
||||
-v, --verbose count verbose log output, "-v" basic logging, "-vv" debug level logging
|
||||
### Command
|
||||
|
||||
Use "ctrld service [command] --help" for more information about a command.
|
||||
```
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe service start
|
||||
```
|
||||
|
||||
## Router Mode
|
||||
You can run `ctrld` on any supported router, which will function similarly to the Service Mode mentioned above. The list of supported routers and firmware includes:
|
||||
- OpenWRT
|
||||
- DD-WRT
|
||||
- Asus Merlin
|
||||
- GL.iNet
|
||||
- Ubiquiti
|
||||
|
||||
In order to start `ctrld` as a DNS provider, simply run `./ctrld setup auto` command.
|
||||
|
||||
In this mode, and when Control D upstreams are used, the router will [relay your network topology](https://docs.controld.com/docs/device-clients) to Control D (LAN IPs, MAC addresses, and hostnames), and you will be able to see your LAN devices in the web panel, view analytics and apply unique profiles to them.
|
||||
|
||||
### Control D Auto Configuration
|
||||
Application can be started with a specific resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `run` (foreground) or `start` (service) or `setup` (router) modes.
|
||||
|
||||
The following command will start the application in foreground mode, using the free "p2" resolver, which blocks Ads & Trackers.
|
||||
|
||||
```shell
|
||||
./ctrld run --cd p2
|
||||
```
|
||||
|
||||
Alternatively, you can use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Device.
|
||||
|
||||
```shell
|
||||
./ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
You can do the same while starting in router mode:
|
||||
```shell
|
||||
./ctrld setup auto --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above commands (in service or router modes only), the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `service uninstall` sub-commands
|
||||
- Your default network interface will be updated to use the listener started by the service
|
||||
- All OS DNS queries will be sent to the listener
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld service start
|
||||
```
|
||||
|
||||
# Configuration
|
||||
See [Configuration Docs](docs/config.md).
|
||||
`ctrld` can be configured in variety of different ways, which include: API, local config file or via cli launch args.
|
||||
|
||||
## Example
|
||||
- Start `listener.0` on 127.0.0.1:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `upstream.0` via DoH protocol
|
||||
## API Based Auto Configuration
|
||||
Application can be started with a specific Control D resolver config, instead of the default one. Simply supply your Resolver ID with a `--cd` flag, when using the `start` (service) mode. This mode is used when the 1 liner installer command from the Control D onboarding guide is executed.
|
||||
|
||||
### Default Config
|
||||
The following command will use your own personal Control D Device resolver, and start the application in service mode. Your resolver ID is displayed on the "Show Resolvers" screen for the relevant Control D Endpoint.
|
||||
|
||||
Windows (Admin Shell)
|
||||
```shell
|
||||
ctrld.exe start --cd abcd1234
|
||||
```
|
||||
|
||||
Linux or Macos
|
||||
```shell
|
||||
sudo ctrld start --cd abcd1234
|
||||
```
|
||||
|
||||
Once you run the above command, the following things will happen:
|
||||
- You resolver configuration will be fetched from the API, and config file templated with the resolver data
|
||||
- Application will start as a service, and keep running (even after reboot) until you run the `stop` or `uninstall` sub-commands
|
||||
- All physical network interface will be updated to use the listener started by the service
|
||||
- All DNS queries will be sent to the listener
|
||||
|
||||
## Manual Configuration
|
||||
`ctrld` is entirely config driven and can be configured in many different ways, please see [Configuration Docs](docs/config.md).
|
||||
|
||||
### Example
|
||||
```toml
|
||||
[listener]
|
||||
|
||||
[listener.0]
|
||||
ip = "127.0.0.1"
|
||||
ip = '0.0.0.0'
|
||||
port = 53
|
||||
restricted = false
|
||||
|
||||
[network]
|
||||
|
||||
@@ -212,10 +216,6 @@ See [Configuration Docs](docs/config.md).
|
||||
cidrs = ["0.0.0.0/0"]
|
||||
name = "Network 0"
|
||||
|
||||
[service]
|
||||
log_level = "info"
|
||||
log_path = ""
|
||||
|
||||
[upstream]
|
||||
|
||||
[upstream.0]
|
||||
@@ -224,26 +224,26 @@ See [Configuration Docs](docs/config.md).
|
||||
name = "Control D - Anti-Malware"
|
||||
timeout = 5000
|
||||
type = "doh"
|
||||
|
||||
[upstream.1]
|
||||
bootstrap_ip = "76.76.2.11"
|
||||
endpoint = "p2.freedns.controld.com"
|
||||
name = "Control D - No Ads"
|
||||
timeout = 3000
|
||||
type = "doq"
|
||||
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
The above is the most basic example, which will work out of the box. If you're looking to do advanced configurations using policies, see [Configuration Docs](docs/config.md) for complete documentation of the config file.
|
||||
The above basic config will:
|
||||
- Start listener on 0.0.0.0:53
|
||||
- Accept queries from any source address
|
||||
- Send all queries to `https://freedns.controld.com/p1` using DoH protocol
|
||||
|
||||
You can also supply configuration via launch argeuments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
## CLI Args
|
||||
If you're unable to use a config file, `ctrld` can be be supplied with basic configuration via launch arguments, in [Ephemeral Mode](docs/ephemeral_mode.md).
|
||||
|
||||
### Example
|
||||
```
|
||||
ctrld run --listen=127.0.0.1:53 --primary_upstream=https://freedns.controld.com/p2 --secondary_upstream=10.0.10.1:53 --domains=*.company.int,very-secure.local --log /path/to/log.log
|
||||
```
|
||||
|
||||
The above will start a foreground process and:
|
||||
- Listen on `127.0.0.1:53` for DNS queries
|
||||
- Forward all queries to `https://freedns.controld.com/p2` using DoH protocol, while...
|
||||
- Excluding `*.company.int` and `very-secure.local` matching queries, that are forwarded to `10.0.10.1:53`
|
||||
- Write a debug log to `/path/to/log.log`
|
||||
|
||||
## Contributing
|
||||
See [Contribution Guideline](./docs/contributing.md)
|
||||
|
||||
## Roadmap
|
||||
The following functionality is on the roadmap and will be available in future releases.
|
||||
- Prometheus metrics exporter
|
||||
- DNS intercept mode
|
||||
- Support for more routers (let us know which ones)
|
||||
|
||||
@@ -5,7 +5,18 @@ type ClientInfoCtxKey struct{}
|
||||
|
||||
// ClientInfo represents ctrld's clients information.
|
||||
type ClientInfo struct {
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Mac string
|
||||
IP string
|
||||
Hostname string
|
||||
Self bool
|
||||
ClientIDPref string
|
||||
}
|
||||
|
||||
// LeaseFileFormat specifies the format of DHCP lease file.
|
||||
type LeaseFileFormat string
|
||||
|
||||
const (
|
||||
Dnsmasq LeaseFileFormat = "dnsmasq"
|
||||
IscDhcpd LeaseFileFormat = "isc-dhcpd"
|
||||
KeaDHCP4 LeaseFileFormat = "kea-dhcp4"
|
||||
)
|
||||
|
||||
10
cmd/cli/ad_others.go
Normal file
10
cmd/cli/ad_others.go
Normal file
@@ -0,0 +1,10 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule if present.
|
||||
func addExtraSplitDnsRule(_ *ctrld.Config) bool { return false }
|
||||
73
cmd/cli/ad_windows.go
Normal file
73
cmd/cli/ad_windows.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/microsoft/wmi/pkg/base/host"
|
||||
hh "github.com/microsoft/wmi/pkg/hardware/host"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// addExtraSplitDnsRule adds split DNS rule for domain if it's part of active directory.
|
||||
func addExtraSplitDnsRule(cfg *ctrld.Config) bool {
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Msgf("Unable to get active directory domain: %v", err)
|
||||
return false
|
||||
}
|
||||
if domain == "" {
|
||||
mainLog.Load().Debug().Msg("No active directory domain found")
|
||||
return false
|
||||
}
|
||||
// Network rules are lowercase during toml config marshaling,
|
||||
// lowercase the domain here too for consistency.
|
||||
domain = strings.ToLower(domain)
|
||||
domainRuleAdded := addSplitDnsRule(cfg, domain)
|
||||
wildcardDomainRuleRuleAdded := addSplitDnsRule(cfg, "*."+strings.TrimPrefix(domain, "."))
|
||||
return domainRuleAdded || wildcardDomainRuleRuleAdded
|
||||
}
|
||||
|
||||
// addSplitDnsRule adds split-rule for given domain if there's no existed rule.
|
||||
// The return value indicates whether the split-rule was added or not.
|
||||
func addSplitDnsRule(cfg *ctrld.Config, domain string) bool {
|
||||
for n, lc := range cfg.Listener {
|
||||
if lc.Policy == nil {
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{}
|
||||
}
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
if _, ok := rule[domain]; ok {
|
||||
mainLog.Load().Debug().Msgf("Split-rule %q already existed for listener.%s", domain, n)
|
||||
return false
|
||||
}
|
||||
}
|
||||
mainLog.Load().Debug().Msgf("Adding split-rule %q for listener.%s", domain, n)
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getActiveDirectoryDomain returns AD domain name of this computer.
|
||||
func getActiveDirectoryDomain() (string, error) {
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
whost := host.NewWmiLocalHost()
|
||||
cs, err := hh.GetComputerSystem(whost)
|
||||
if cs != nil {
|
||||
defer cs.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pod, err := cs.GetPropertyPartOfDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if pod {
|
||||
return cs.GetPropertyDomain()
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
71
cmd/cli/ad_windows_test.go
Normal file
71
cmd/cli/ad_windows_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_getActiveDirectoryDomain(t *testing.T) {
|
||||
start := time.Now()
|
||||
domain, err := getActiveDirectoryDomain()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
domainPowershell, err := getActiveDirectoryDomainPowershell()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
if domain != domainPowershell {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", domainPowershell, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func getActiveDirectoryDomainPowershell() (string, error) {
|
||||
cmd := "$obj = Get-WmiObject Win32_ComputerSystem; if ($obj.PartOfDomain) { $obj.Domain }"
|
||||
output, err := powershell(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get domain name: %w, output:\n\n%s", err, string(output))
|
||||
}
|
||||
return string(output), nil
|
||||
}
|
||||
|
||||
func Test_addSplitDnsRule(t *testing.T) {
|
||||
newCfg := func(domains ...string) *ctrld.Config {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
lc := cfg.Listener["0"]
|
||||
for _, domain := range domains {
|
||||
lc.Policy.Rules = append(lc.Policy.Rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *ctrld.Config
|
||||
domain string
|
||||
added bool
|
||||
}{
|
||||
{"added", newCfg(), "example.com", true},
|
||||
{"TLD existed", newCfg("example.com"), "*.example.com", true},
|
||||
{"wildcard existed", newCfg("*.example.com"), "example.com", true},
|
||||
{"not added TLD", newCfg("example.com", "*.example.com"), "example.com", false},
|
||||
{"not added wildcard", newCfg("example.com", "*.example.com"), "*.example.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
added := addSplitDnsRule(tc.cfg, tc.domain)
|
||||
assert.Equal(t, tc.added, added)
|
||||
})
|
||||
}
|
||||
}
|
||||
5
cmd/cli/cgo.go
Normal file
5
cmd/cli/cgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = true
|
||||
1903
cmd/cli/cli.go
Normal file
1903
cmd/cli/cli.go
Normal file
File diff suppressed because it is too large
Load Diff
46
cmd/cli/cli_test.go
Normal file
46
cmd/cli/cli_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_writeConfigFile(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
// simulate --config CLI flag by setting configPath manually.
|
||||
configPath = filepath.Join(tmpdir, "ctrld.toml")
|
||||
_, err := os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
assert.NoError(t, writeConfigFile(&cfg))
|
||||
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_isStableVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ver string
|
||||
isStable bool
|
||||
}{
|
||||
{"stable", "v1.3.5", true},
|
||||
{"pre", "v1.3.5-next", false},
|
||||
{"pre with commit hash", "v1.3.5-next-asdf", false},
|
||||
{"dev", "dev", false},
|
||||
{"empty", "dev", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isStableVersion(tc.ver); got != tc.isStable {
|
||||
t.Errorf("unexpected result for %s, want: %v, got: %v", tc.ver, tc.isStable, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
141
cmd/cli/commands_clients.go
Normal file
141
cmd/cli/commands_clients.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/clientinfo"
|
||||
)
|
||||
|
||||
// ClientsCommand handles clients-related operations
|
||||
type ClientsCommand struct {
|
||||
controlClient *controlClient
|
||||
}
|
||||
|
||||
// NewClientsCommand creates a new clients command handler
|
||||
func NewClientsCommand() (*ClientsCommand, error) {
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find ctrld home dir: %w", err)
|
||||
}
|
||||
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
return &ClientsCommand{
|
||||
controlClient: cc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListClients lists all connected clients
|
||||
func (cc *ClientsCommand) ListClients(cmd *cobra.Command, args []string) error {
|
||||
// Check service status first
|
||||
sc := NewServiceCommand()
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := cc.controlClient.post(listClientsPath, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get clients: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var clients []*clientinfo.Client
|
||||
if err := json.NewDecoder(resp.Body).Decode(&clients); err != nil {
|
||||
return fmt.Errorf("failed to decode clients result: %w", err)
|
||||
}
|
||||
|
||||
map2Slice := func(m map[string]struct{}) []string {
|
||||
s := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
if k == "" { // skip empty source from output.
|
||||
continue
|
||||
}
|
||||
s = append(s, k)
|
||||
}
|
||||
sort.Strings(s)
|
||||
return s
|
||||
}
|
||||
|
||||
// If metrics is enabled, server set this for all clients, so we can check only the first one.
|
||||
// Ideally, we may have a field in response to indicate that query count should be shown, but
|
||||
// it would break earlier version of ctrld, which only look list of clients in response.
|
||||
withQueryCount := len(clients) > 0 && clients[0].IncludeQueryCount
|
||||
data := make([][]string, len(clients))
|
||||
for i, c := range clients {
|
||||
row := []string{
|
||||
c.IP.String(),
|
||||
c.Hostname,
|
||||
c.Mac,
|
||||
strings.Join(map2Slice(c.Source), ","),
|
||||
}
|
||||
if withQueryCount {
|
||||
row = append(row, strconv.FormatInt(c.QueryCount, 10))
|
||||
}
|
||||
data[i] = row
|
||||
}
|
||||
|
||||
table := tablewriter.NewWriter(os.Stdout)
|
||||
headers := []string{"IP", "Hostname", "Mac", "Discovered"}
|
||||
if withQueryCount {
|
||||
headers = append(headers, "Queries")
|
||||
}
|
||||
table.SetHeader(headers)
|
||||
table.SetAutoFormatHeaders(false)
|
||||
table.AppendBulk(data)
|
||||
table.Render()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitClientsCmd creates the clients command with proper logic
|
||||
func InitClientsCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
listClientsCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List clients that ctrld discovered",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cc, err := NewClientsCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cc.ListClients(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
clientsCmd := &cobra.Command{
|
||||
Use: "clients",
|
||||
Short: "Manage clients",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
listClientsCmd.Use,
|
||||
},
|
||||
}
|
||||
clientsCmd.AddCommand(listClientsCmd)
|
||||
rootCmd.AddCommand(clientsCmd)
|
||||
|
||||
return clientsCmd
|
||||
}
|
||||
87
cmd/cli/commands_interfaces.go
Normal file
87
cmd/cli/commands_interfaces.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// InterfacesCommand handles interfaces-related operations
|
||||
type InterfacesCommand struct{}
|
||||
|
||||
// NewInterfacesCommand creates a new interfaces command handler
|
||||
func NewInterfacesCommand() (*InterfacesCommand, error) {
|
||||
return &InterfacesCommand{}, nil
|
||||
}
|
||||
|
||||
// ListInterfaces lists all network interfaces
|
||||
func (ic *InterfacesCommand) ListInterfaces(cmd *cobra.Command, args []string) error {
|
||||
withEachPhysicalInterfaces("", "Interface list", func(i *net.Interface) error {
|
||||
fmt.Printf("Index : %d\n", i.Index)
|
||||
fmt.Printf("Name : %s\n", i.Name)
|
||||
var status string
|
||||
if i.Flags&net.FlagUp != 0 {
|
||||
status = "Up"
|
||||
} else {
|
||||
status = "Down"
|
||||
}
|
||||
fmt.Printf("Status: %s\n", status)
|
||||
addrs, _ := i.Addrs()
|
||||
for i, ipaddr := range addrs {
|
||||
if i == 0 {
|
||||
fmt.Printf("Addrs : %v\n", ipaddr)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" %v\n", ipaddr)
|
||||
}
|
||||
nss, err := currentStaticDNS(i)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Failed to get DNS")
|
||||
}
|
||||
if len(nss) == 0 {
|
||||
nss = currentDNS(i)
|
||||
}
|
||||
for i, dns := range nss {
|
||||
if i == 0 {
|
||||
fmt.Printf("DNS : %s\n", dns)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" : %s\n", dns)
|
||||
}
|
||||
println()
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitInterfacesCmd creates the interfaces command with proper logic
|
||||
func InitInterfacesCmd(_ *cobra.Command) *cobra.Command {
|
||||
listInterfacesCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List network interfaces",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ic, err := NewInterfacesCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ic.ListInterfaces(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
interfacesCmd := &cobra.Command{
|
||||
Use: "interfaces",
|
||||
Short: "Manage network interfaces",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
listInterfacesCmd.Use,
|
||||
},
|
||||
}
|
||||
interfacesCmd.AddCommand(listInterfacesCmd)
|
||||
|
||||
return interfacesCmd
|
||||
}
|
||||
175
cmd/cli/commands_log.go
Normal file
175
cmd/cli/commands_log.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/docker/go-units"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// LogCommand handles log-related operations
|
||||
type LogCommand struct {
|
||||
controlClient *controlClient
|
||||
}
|
||||
|
||||
// NewLogCommand creates a new log command handler
|
||||
func NewLogCommand() (*LogCommand, error) {
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find ctrld home dir: %w", err)
|
||||
}
|
||||
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
return &LogCommand{
|
||||
controlClient: cc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// warnRuntimeLoggingNotEnabled logs a warning about runtime logging not being enabled
|
||||
func (lc *LogCommand) warnRuntimeLoggingNotEnabled() {
|
||||
mainLog.Load().Warn().Msg("Runtime debug logging is not enabled")
|
||||
mainLog.Load().Warn().Msg(`ctrld may be running without "--cd" flag or logging is already enabled`)
|
||||
}
|
||||
|
||||
// SendLogs sends runtime debug logs to ControlD
|
||||
func (lc *LogCommand) SendLogs(cmd *cobra.Command, args []string) error {
|
||||
sc := NewServiceCommand()
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := lc.controlClient.post(sendLogsPath, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send logs: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusServiceUnavailable:
|
||||
mainLog.Load().Warn().Msg("Runtime logs could only be sent once per minute")
|
||||
return nil
|
||||
case http.StatusMovedPermanently:
|
||||
lc.warnRuntimeLoggingNotEnabled()
|
||||
return nil
|
||||
}
|
||||
|
||||
var logs logSentResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil {
|
||||
return fmt.Errorf("failed to decode sent logs result: %w", err)
|
||||
}
|
||||
|
||||
if logs.Error != "" {
|
||||
return fmt.Errorf("failed to send logs: %s", logs.Error)
|
||||
}
|
||||
|
||||
mainLog.Load().Notice().Msgf("Sent %s of runtime logs", units.BytesSize(float64(logs.Size)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ViewLogs views current runtime debug logs
|
||||
func (lc *LogCommand) ViewLogs(cmd *cobra.Command, args []string) error {
|
||||
sc := NewServiceCommand()
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
mainLog.Load().Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
mainLog.Load().Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
resp, err := lc.controlClient.post(viewLogsPath, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get logs: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusMovedPermanently:
|
||||
lc.warnRuntimeLoggingNotEnabled()
|
||||
return nil
|
||||
case http.StatusBadRequest:
|
||||
mainLog.Load().Warn().Msg("Runtime debug logs are not available")
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to read response body")
|
||||
}
|
||||
mainLog.Load().Warn().Msgf("ctrld process response:\n\n%s\n", string(buf))
|
||||
return nil
|
||||
case http.StatusOK:
|
||||
}
|
||||
|
||||
var logs logViewResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&logs); err != nil {
|
||||
return fmt.Errorf("failed to decode view logs result: %w", err)
|
||||
}
|
||||
|
||||
fmt.Print(logs.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitLogCmd creates the log command with proper logic
|
||||
func InitLogCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
lc, err := NewLogCommand()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create log command: %v", err))
|
||||
}
|
||||
|
||||
logSendCmd := &cobra.Command{
|
||||
Use: "send",
|
||||
Short: "Send runtime debug logs to ControlD",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: lc.SendLogs,
|
||||
}
|
||||
|
||||
logViewCmd := &cobra.Command{
|
||||
Use: "view",
|
||||
Short: "View current runtime debug logs",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: lc.ViewLogs,
|
||||
}
|
||||
|
||||
logCmd := &cobra.Command{
|
||||
Use: "log",
|
||||
Short: "Manage runtime debug logs",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
logSendCmd.Use,
|
||||
logViewCmd.Use,
|
||||
},
|
||||
}
|
||||
logCmd.AddCommand(logSendCmd)
|
||||
logCmd.AddCommand(logViewCmd)
|
||||
rootCmd.AddCommand(logCmd)
|
||||
|
||||
return logCmd
|
||||
}
|
||||
59
cmd/cli/commands_run.go
Normal file
59
cmd/cli/commands_run.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// RunCommand handles run-related operations
|
||||
type RunCommand struct {
|
||||
// Add any dependencies here if needed in the future
|
||||
}
|
||||
|
||||
// NewRunCommand creates a new run command handler
|
||||
func NewRunCommand() *RunCommand {
|
||||
return &RunCommand{}
|
||||
}
|
||||
|
||||
// Run implements the logic for the run command
|
||||
func (rc *RunCommand) Run(cmd *cobra.Command, args []string) {
|
||||
RunCobraCommand(cmd)
|
||||
}
|
||||
|
||||
// InitRunCmd creates the run command with proper logic
|
||||
func InitRunCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
rc := NewRunCommand()
|
||||
|
||||
runCmd := &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "Run the DNS proxy server",
|
||||
Args: cobra.NoArgs,
|
||||
Run: rc.Run,
|
||||
}
|
||||
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
|
||||
runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
runCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
|
||||
runCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
|
||||
runCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
|
||||
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = runCmd.Flags().MarkHidden("dev")
|
||||
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
|
||||
_ = runCmd.Flags().MarkHidden("homedir")
|
||||
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
runCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
runCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
|
||||
|
||||
runCmd.FParseErrWhitelist = cobra.FParseErrWhitelist{UnknownFlags: true}
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
return runCmd
|
||||
}
|
||||
256
cmd/cli/commands_service.go
Normal file
256
cmd/cli/commands_service.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// filterEmptyStrings removes empty strings from a slice
|
||||
// This is used to clean up command line arguments and configuration values
|
||||
func filterEmptyStrings(slice []string) []string {
|
||||
var result []string
|
||||
for _, s := range slice {
|
||||
if s != "" {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ServiceCommand handles service-related operations
|
||||
// This encapsulates all service management functionality for the CLI
|
||||
type ServiceCommand struct {
|
||||
serviceManager *ServiceManager
|
||||
}
|
||||
|
||||
// initializeServiceManager creates a service manager with default configuration
|
||||
// This sets up the basic service infrastructure needed for all service operations
|
||||
func (sc *ServiceCommand) initializeServiceManager() (service.Service, *prog, error) {
|
||||
svcConfig := sc.createServiceConfig()
|
||||
return sc.initializeServiceManagerWithServiceConfig(svcConfig)
|
||||
}
|
||||
|
||||
// initializeServiceManagerWithServiceConfig creates a service manager with the given configuration
|
||||
// This allows for custom service configuration while maintaining the same initialization pattern
|
||||
func (sc *ServiceCommand) initializeServiceManagerWithServiceConfig(svcConfig *service.Config) (service.Service, *prog, error) {
|
||||
p := &prog{}
|
||||
|
||||
s, err := sc.newService(p, svcConfig)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
sc.serviceManager = &ServiceManager{prog: p, svc: s}
|
||||
return s, p, nil
|
||||
}
|
||||
|
||||
// newService creates a new service instance using the provided program and configuration.
|
||||
// This abstracts the service creation process for different operating systems
|
||||
func (sc *ServiceCommand) newService(p *prog, svcConfig *service.Config) (service.Service, error) {
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewServiceCommand creates a new service command handler
|
||||
// This provides a clean factory method for creating service command instances
|
||||
func NewServiceCommand() *ServiceCommand {
|
||||
return &ServiceCommand{}
|
||||
}
|
||||
|
||||
// createServiceConfig creates a properly initialized service configuration
|
||||
// This ensures consistent service naming and description across all platforms
|
||||
func (sc *ServiceCommand) createServiceConfig() *service.Config {
|
||||
return &service.Config{
|
||||
Name: ctrldServiceName,
|
||||
DisplayName: "Control-D Helper Service",
|
||||
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
|
||||
Option: service.KeyValue{},
|
||||
}
|
||||
}
|
||||
|
||||
// InitServiceCmd creates the service command with proper logic and aliases
|
||||
// This sets up all service-related subcommands with appropriate permissions and flags
|
||||
func InitServiceCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
// Create service command handlers
|
||||
sc := NewServiceCommand()
|
||||
|
||||
startCmd, startCmdAlias := createStartCommands(sc)
|
||||
rootCmd.AddCommand(startCmdAlias)
|
||||
|
||||
// Stop command
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop",
|
||||
Short: "Stop the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Stop,
|
||||
}
|
||||
stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
stopCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`)
|
||||
_ = stopCmd.Flags().MarkHidden("pin")
|
||||
|
||||
// Restart command
|
||||
restartCmd := &cobra.Command{
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Restart,
|
||||
}
|
||||
|
||||
// Status command
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: sc.Status,
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
// On darwin, running status command without privileges may return wrong information.
|
||||
statusCmd.PreRun = func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
}
|
||||
}
|
||||
|
||||
// Reload command
|
||||
reloadCmd := &cobra.Command{
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Reload,
|
||||
}
|
||||
|
||||
// Uninstall command
|
||||
uninstallCmd := &cobra.Command{
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Long: `Stop and uninstall the ctrld service.
|
||||
|
||||
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Uninstall,
|
||||
}
|
||||
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
uninstallCmd.Flags().Int64VarP(&deactivationPin, "pin", "", defaultDeactivationPin, `Pin code for stopping ctrld`)
|
||||
_ = uninstallCmd.Flags().MarkHidden("pin")
|
||||
uninstallCmd.Flags().BoolVarP(&cleanup, "cleanup", "", false, `Removing ctrld binary and config files`)
|
||||
|
||||
// Interfaces command - use the existing InitInterfacesCmd function
|
||||
interfacesCmd := InitInterfacesCmd(rootCmd)
|
||||
|
||||
stopCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "stop",
|
||||
Short: "Quick stop service and remove DNS from interface",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
return stopCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
|
||||
rootCmd.AddCommand(stopCmdAlias)
|
||||
|
||||
// Create aliases for other service commands
|
||||
restartCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return restartCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(restartCmdAlias)
|
||||
|
||||
reloadCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "reload",
|
||||
Short: "Reload the ctrld service",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
return reloadCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
rootCmd.AddCommand(reloadCmdAlias)
|
||||
|
||||
statusCmdAlias := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: statusCmd.RunE,
|
||||
}
|
||||
rootCmd.AddCommand(statusCmdAlias)
|
||||
|
||||
uninstallCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Long: `Stop and uninstall the ctrld service.
|
||||
|
||||
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
return uninstallCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
uninstallCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
uninstallCmdAlias.Flags().AddFlagSet(uninstallCmd.Flags())
|
||||
rootCmd.AddCommand(uninstallCmdAlias)
|
||||
|
||||
// Create service command
|
||||
serviceCmd := &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "Manage ctrld service",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
}
|
||||
serviceCmd.ValidArgs = make([]string, 7)
|
||||
serviceCmd.ValidArgs[0] = startCmd.Use
|
||||
serviceCmd.ValidArgs[1] = stopCmd.Use
|
||||
serviceCmd.ValidArgs[2] = restartCmd.Use
|
||||
serviceCmd.ValidArgs[3] = reloadCmd.Use
|
||||
serviceCmd.ValidArgs[4] = statusCmd.Use
|
||||
serviceCmd.ValidArgs[5] = uninstallCmd.Use
|
||||
serviceCmd.ValidArgs[6] = interfacesCmd.Use
|
||||
|
||||
serviceCmd.AddCommand(startCmd)
|
||||
serviceCmd.AddCommand(stopCmd)
|
||||
serviceCmd.AddCommand(restartCmd)
|
||||
serviceCmd.AddCommand(reloadCmd)
|
||||
serviceCmd.AddCommand(statusCmd)
|
||||
serviceCmd.AddCommand(uninstallCmd)
|
||||
serviceCmd.AddCommand(interfacesCmd)
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
|
||||
return serviceCmd
|
||||
}
|
||||
41
cmd/cli/commands_service_manager.go
Normal file
41
cmd/cli/commands_service_manager.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// dialSocketControlServerTimeout is the default timeout to wait when ping control server.
|
||||
const dialSocketControlServerTimeout = 30 * time.Second
|
||||
|
||||
// ServiceManager handles service operations
|
||||
type ServiceManager struct {
|
||||
prog *prog
|
||||
svc service.Service
|
||||
}
|
||||
|
||||
// NewServiceManager creates a new service manager
|
||||
func NewServiceManager() (*ServiceManager, error) {
|
||||
p := &prog{}
|
||||
|
||||
// Create a proper service configuration
|
||||
svcConfig := &service.Config{
|
||||
Name: ctrldServiceName,
|
||||
DisplayName: "Control-D Helper Service",
|
||||
Description: "A highly configurable, multi-protocol DNS forwarding proxy",
|
||||
Option: service.KeyValue{},
|
||||
}
|
||||
|
||||
s, err := newService(p, svcConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
return &ServiceManager{prog: p, svc: s}, nil
|
||||
}
|
||||
|
||||
// Status returns the current service status
|
||||
func (sm *ServiceManager) Status() (service.Status, error) {
|
||||
return sm.svc.Status()
|
||||
}
|
||||
67
cmd/cli/commands_service_reload.go
Normal file
67
cmd/cli/commands_service_reload.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Reload implements the logic from cmdReload.Run
|
||||
func (sc *ServiceCommand) Reload(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service reload command started")
|
||||
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
logger.Warn().Msg("Service is not running")
|
||||
return nil
|
||||
}
|
||||
|
||||
dir, err := socketDir()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Failed to find ctrld home dir")
|
||||
}
|
||||
|
||||
cc := newControlClient(filepath.Join(dir, ctrldControlUnixSock))
|
||||
resp, err := cc.post(reloadPath, nil)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Failed to send reload signal to ctrld")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
case http.StatusOK:
|
||||
logger.Notice().Msg("Service reloaded")
|
||||
case http.StatusCreated:
|
||||
logger.Warn().Msg("Service was reloaded, but new config requires service restart.")
|
||||
logger.Warn().Msg("Restarting service")
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
return sc.Restart(cmd, args)
|
||||
default:
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Could not read response from control server")
|
||||
}
|
||||
logger.Error().Err(err).Msgf("Failed to reload ctrld: %s", string(buf))
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service reload command completed")
|
||||
return nil
|
||||
}
|
||||
111
cmd/cli/commands_service_restart.go
Normal file
111
cmd/cli/commands_service_restart.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Restart implements the logic from cmdRestart.Run
|
||||
func (sc *ServiceCommand) Restart(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service restart command started")
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
cdUID = curCdUID()
|
||||
cdMode := cdUID != ""
|
||||
|
||||
s, p, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
|
||||
var validateConfigErr error
|
||||
if cdMode {
|
||||
logger.Debug().Msg("Validating ControlD remote config")
|
||||
validateConfigErr = doValidateCdRemoteConfig(cdUID, false)
|
||||
if validateConfigErr != nil {
|
||||
logger.Warn().Err(validateConfigErr).Msg("ControlD remote config validation failed")
|
||||
}
|
||||
}
|
||||
|
||||
if ir := runningIface(s); ir != nil {
|
||||
iface = ir.Name
|
||||
}
|
||||
|
||||
doRestart := func() bool {
|
||||
logger.Debug().Msg("Starting service restart sequence")
|
||||
|
||||
tasks := []task{
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
time.Sleep(time.Second * 1)
|
||||
return nil
|
||||
}, false, "Waiting for service to stop"},
|
||||
}
|
||||
if !doTasks(tasks) {
|
||||
logger.Error().Msg("Service stop tasks failed")
|
||||
return false
|
||||
}
|
||||
tasks = []task{
|
||||
{s.Start, true, "Start"},
|
||||
}
|
||||
success := doTasks(tasks)
|
||||
if success {
|
||||
logger.Debug().Msg("Service restart sequence completed successfully")
|
||||
} else {
|
||||
logger.Error().Msg("Service restart sequence failed")
|
||||
}
|
||||
return success
|
||||
}
|
||||
|
||||
if doRestart() {
|
||||
if dir, err := socketDir(); err == nil {
|
||||
timeout := dialSocketControlServerTimeout
|
||||
if validateConfigErr != nil {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
if cc := newSocketControlClientWithTimeout(context.TODO(), s, dir, timeout); cc != nil {
|
||||
_, _ = cc.post(ifacePath, nil)
|
||||
logger.Debug().Msg("Control server ping successful")
|
||||
} else {
|
||||
logger.Warn().Err(err).Msg("Service was restarted, but ctrld process may not be ready yet")
|
||||
}
|
||||
} else {
|
||||
logger.Warn().Err(err).Msg("Service was restarted, but could not ping the control server")
|
||||
}
|
||||
logger.Notice().Msg("Service restarted")
|
||||
} else {
|
||||
logger.Error().Msg("Service restart failed")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service restart command completed")
|
||||
return nil
|
||||
}
|
||||
387
cmd/cli/commands_service_start.go
Normal file
387
cmd/cli/commands_service_start.go
Normal file
@@ -0,0 +1,387 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Start implements the logic from cmdStart.Run
|
||||
func (sc *ServiceCommand) Start(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service start command started")
|
||||
|
||||
checkStrFlagEmpty(cmd, cdUidFlagName)
|
||||
checkStrFlagEmpty(cmd, cdOrgFlagName)
|
||||
validateCdAndNextDNSFlags()
|
||||
|
||||
svcConfig := sc.createServiceConfig()
|
||||
osArgs := os.Args[2:]
|
||||
osArgs = filterEmptyStrings(osArgs)
|
||||
if os.Args[1] == "service" {
|
||||
osArgs = os.Args[3:]
|
||||
}
|
||||
setDependencies(svcConfig)
|
||||
svcConfig.Arguments = append([]string{"run"}, osArgs...)
|
||||
|
||||
// Initialize service manager with proper configuration
|
||||
s, p, err := sc.initializeServiceManagerWithServiceConfig(svcConfig)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
p.preRun()
|
||||
|
||||
status, err := s.Status()
|
||||
isCtrldRunning := status == service.StatusRunning
|
||||
isCtrldInstalled := !errors.Is(err, service.ErrNotInstalled)
|
||||
|
||||
// Get current running iface, if any.
|
||||
var currentIface *ifaceResponse
|
||||
|
||||
// If pin code was set, do not allow running start command.
|
||||
if isCtrldRunning {
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
logger.Error().Msg("Deactivation pin check failed")
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
currentIface = runningIface(s)
|
||||
logger.Debug().Msgf("Current interface on start: %v", currentIface)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
reportSetDnsOk := func(sockDir string) {
|
||||
if cc := newSocketControlClient(ctx, s, sockDir); cc != nil {
|
||||
if resp, _ := cc.post(ifacePath, nil); resp != nil && resp.StatusCode == http.StatusOK {
|
||||
if iface == "auto" {
|
||||
iface = defaultIfaceName()
|
||||
}
|
||||
res := &ifaceResponse{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(res); err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get iface info")
|
||||
return
|
||||
}
|
||||
if res.OK {
|
||||
name := res.Name
|
||||
if iff, err := net.InterfaceByName(name); err == nil {
|
||||
_, _ = patchNetIfaceName(iff)
|
||||
name = iff.Name
|
||||
}
|
||||
logger := logger.With().Str("iface", name)
|
||||
logger.Debug().Msg("Setting DNS successfully")
|
||||
if res.All {
|
||||
// Log that DNS is set for other interfaces.
|
||||
withEachPhysicalInterfaces(
|
||||
name,
|
||||
"set DNS",
|
||||
func(i *net.Interface) error { return nil },
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No config path, generating config in HOME directory.
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
|
||||
logServerStarted := make(chan struct{})
|
||||
stopLogCh := make(chan struct{})
|
||||
ud, err := userHomeDir()
|
||||
sockDir := ud
|
||||
var logServerSocketPath string
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get user home directory")
|
||||
logger.Warn().Msg("Log server did not start")
|
||||
close(logServerStarted)
|
||||
} else {
|
||||
setWorkingDirectory(svcConfig, ud)
|
||||
if configPath == "" && writeDefaultConfig {
|
||||
defaultConfigFile = filepath.Join(ud, defaultConfigFile)
|
||||
}
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--homedir="+ud)
|
||||
if d, err := socketDir(); err == nil {
|
||||
sockDir = d
|
||||
}
|
||||
logServerSocketPath = filepath.Join(sockDir, ctrldLogUnixSock)
|
||||
_ = os.Remove(logServerSocketPath)
|
||||
go func() {
|
||||
defer os.Remove(logServerSocketPath)
|
||||
|
||||
close(logServerStarted)
|
||||
|
||||
// Start HTTP log server
|
||||
if err := httpLogServer(logServerSocketPath, stopLogCh); err != nil && err != http.ErrServerClosed {
|
||||
logger.Warn().Err(err).Msg("Failed to serve HTTP log server")
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
<-logServerStarted
|
||||
|
||||
if !startOnly {
|
||||
startOnly = len(osArgs) == 0
|
||||
}
|
||||
// If user run "ctrld start" and ctrld is already installed, starting existing service.
|
||||
if startOnly && isCtrldInstalled {
|
||||
tryReadingConfigWithNotice(false, true)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
logger.Fatal().Msgf("Failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
// if already running, dont restart
|
||||
if isCtrldRunning {
|
||||
logger.Notice().Msg("Service is already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
tasks := []task{
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}, false, "Save current DNS"},
|
||||
{func() error {
|
||||
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||
}, false, "Configure service failure actions"},
|
||||
{s.Start, true, "Start"},
|
||||
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
|
||||
}
|
||||
logger.Notice().Msg("Starting existing ctrld service")
|
||||
if doTasks(tasks) {
|
||||
logger.Notice().Msg("Service started")
|
||||
sockDir, err := socketDir()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get socket directory")
|
||||
os.Exit(1)
|
||||
}
|
||||
reportSetDnsOk(sockDir)
|
||||
} else {
|
||||
logger.Error().Err(err).Msg("Failed to start existing ctrld service")
|
||||
os.Exit(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if cdUID != "" {
|
||||
_ = doValidateCdRemoteConfig(cdUID, true)
|
||||
} else if uid := cdUIDFromProvToken(); uid != "" {
|
||||
cdUID = uid
|
||||
logger.Debug().Msg("Using uid from provision token")
|
||||
removeOrgFlagsFromArgs(svcConfig)
|
||||
// Pass --cd flag to "ctrld run" command, so the provision token takes no effect.
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--cd="+cdUID)
|
||||
}
|
||||
if cdUID != "" {
|
||||
validateCdUpstreamProtocol()
|
||||
}
|
||||
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
|
||||
tryReadingConfigWithNotice(writeDefaultConfig, true)
|
||||
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
logger.Fatal().Msgf("Failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
|
||||
if nextdns != "" {
|
||||
removeNextDNSFromArgs(svcConfig)
|
||||
}
|
||||
|
||||
// Explicitly passing config, so on system where home directory could not be obtained,
|
||||
// or sub-process env is different with the parent, we still behave correctly and use
|
||||
// the expected config file.
|
||||
if configPath == "" {
|
||||
svcConfig.Arguments = append(svcConfig.Arguments, "--config="+defaultConfigFile)
|
||||
}
|
||||
|
||||
tasks := []task{
|
||||
{s.Stop, false, "Stop"},
|
||||
{func() error { return doGenerateNextDNSConfig(nextdns) }, true, "Checking config"},
|
||||
{func() error { return ensureUninstall(s) }, false, "Ensure uninstall"},
|
||||
//resetDnsTask(p, s, isCtrldInstalled, currentIface),
|
||||
{func() error {
|
||||
// Save current DNS so we can restore later.
|
||||
withEachPhysicalInterfaces("", "saveCurrentStaticDNS", func(i *net.Interface) error {
|
||||
if err := saveCurrentStaticDNS(i); !errors.Is(err, errSaveCurrentStaticDNSNotSupported) && err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return nil
|
||||
}, false, "Save current DNS"},
|
||||
{s.Install, false, "Install"},
|
||||
{func() error {
|
||||
return ConfigureWindowsServiceFailureActions(ctrldServiceName)
|
||||
}, false, "Configure Windows service failure actions"},
|
||||
{s.Start, true, "Start"},
|
||||
// Note that startCmd do not actually write ControlD config, but the config file was
|
||||
// generated after s.Start, so we notice users here for consistent with nextdns mode.
|
||||
{noticeWritingControlDConfig, false, "Notice writing ControlD config"},
|
||||
}
|
||||
logger.Notice().Msg("Starting service")
|
||||
if doTasks(tasks) {
|
||||
// add a small delay to ensure the service is started and did not crash
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
ok, status, err := selfCheckStatus(ctx, s, sockDir)
|
||||
switch {
|
||||
case ok && status == service.StatusRunning:
|
||||
logger.Notice().Msg("Service started")
|
||||
default:
|
||||
marker := append(bytes.Repeat([]byte("="), 32), '\n')
|
||||
// If ctrld service is not running, emitting log obtained from ctrld process.
|
||||
if status != service.StatusRunning || ctx.Err() != nil {
|
||||
logger.Error().Msg("Ctrld service may not have started due to an error or misconfiguration, service log:")
|
||||
_, _ = logger.Write(marker)
|
||||
|
||||
// Wait for log collection to complete
|
||||
<-stopLogCh
|
||||
|
||||
// Retrieve logs from HTTP server if available
|
||||
if logServerSocketPath != "" {
|
||||
hlc := newHTTPLogClient(logServerSocketPath)
|
||||
logs, err := hlc.GetLogs()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get logs from HTTP log server")
|
||||
}
|
||||
if len(logs) == 0 {
|
||||
logger.Write([]byte("<no log output is obtained from ctrld process>\n"))
|
||||
} else {
|
||||
logger.Write(logs)
|
||||
logger.Write([]byte("\n"))
|
||||
}
|
||||
} else {
|
||||
logger.Write([]byte("<no log output from HTTP log server>\n"))
|
||||
}
|
||||
}
|
||||
// Report any error if occurred.
|
||||
if err != nil {
|
||||
_, _ = logger.Write(marker)
|
||||
msg := fmt.Sprintf("An error occurred while performing test query: %s\n", err)
|
||||
logger.Write([]byte(msg))
|
||||
}
|
||||
// If ctrld service is running but selfCheckStatus failed, it could be related
|
||||
// to user's system firewall configuration, notice users about it.
|
||||
if status == service.StatusRunning && err == nil {
|
||||
_, _ = logger.Write(marker)
|
||||
logger.Write([]byte("ctrld service was running, but a DNS query could not be sent to its listener\n"))
|
||||
logger.Write([]byte("Please check your system firewall if it is configured to block/intercept/redirect DNS queries\n"))
|
||||
}
|
||||
|
||||
_, _ = logger.Write(marker)
|
||||
uninstall(p, s)
|
||||
os.Exit(1)
|
||||
}
|
||||
reportSetDnsOk(sockDir)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service start command completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// createStartCommands creates the start command and its alias
|
||||
func createStartCommands(sc *ServiceCommand) (*cobra.Command, *cobra.Command) {
|
||||
// Start command
|
||||
startCmd := &cobra.Command{
|
||||
Use: "start",
|
||||
Short: "Install and start the ctrld service",
|
||||
Long: `Install and start the ctrld service
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
args = filterEmptyStrings(args)
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
|
||||
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: sc.Start,
|
||||
}
|
||||
// Keep these flags in sync with runCmd above, except for "-d"/"--nextdns".
|
||||
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, cdUidFlagName, "", "", "Control D resolver uid")
|
||||
startCmd.Flags().StringVarP(&cdOrg, cdOrgFlagName, "", "", "Control D provision token")
|
||||
startCmd.Flags().StringVarP(&customHostname, customHostnameFlagName, "", "", "Custom hostname passed to ControlD API")
|
||||
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = startCmd.Flags().MarkHidden("dev")
|
||||
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmd.Flags().StringVarP(&nextdns, nextdnsFlagName, "", "", "NextDNS resolver id")
|
||||
startCmd.Flags().StringVarP(&cdUpstreamProto, "proto", "", ctrld.ResolverTypeDOH, `Control D upstream type, either "doh" or "doh3"`)
|
||||
startCmd.Flags().BoolVarP(&skipSelfChecks, "skip_self_checks", "", false, `Skip self checks after installing ctrld service`)
|
||||
startCmd.Flags().BoolVarP(&startOnly, "start_only", "", false, "Do not install new service")
|
||||
_ = startCmd.Flags().MarkHidden("start_only")
|
||||
startCmd.Flags().BoolVarP(&rfc1918, "rfc1918", "", false, "Listen on RFC1918 addresses when 127.0.0.1 is the only listener")
|
||||
|
||||
// Start command alias
|
||||
startCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Quick start service and configure DNS on interface",
|
||||
Long: `Quick start service and configure DNS on interface
|
||||
|
||||
NOTE: running "ctrld start" without any arguments will start already installed ctrld service.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
args = filterEmptyStrings(args)
|
||||
if len(args) > 0 {
|
||||
return fmt.Errorf("'ctrld start' doesn't accept positional arguments\n" +
|
||||
"Use flags instead (e.g. --cd, --iface) or see 'ctrld start --help' for all options")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
if len(os.Args) == 2 {
|
||||
startOnly = true
|
||||
}
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
return startCmd.RunE(cmd, args)
|
||||
},
|
||||
}
|
||||
startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
||||
|
||||
return startCmd, startCmdAlias
|
||||
}
|
||||
41
cmd/cli/commands_service_status.go
Normal file
41
cmd/cli/commands_service_status.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Status implements the logic from cmdStatus.Run
|
||||
func (sc *ServiceCommand) Status(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service status command started")
|
||||
|
||||
s, _, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
status, err := s.Status()
|
||||
if err != nil {
|
||||
logger.Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch status {
|
||||
case service.StatusUnknown:
|
||||
logger.Notice().Msg("Unknown status")
|
||||
os.Exit(2)
|
||||
case service.StatusRunning:
|
||||
logger.Notice().Msg("Service is running")
|
||||
os.Exit(0)
|
||||
case service.StatusStopped:
|
||||
logger.Notice().Msg("Service is stopped")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service status command completed")
|
||||
return nil
|
||||
}
|
||||
61
cmd/cli/commands_service_stop.go
Normal file
61
cmd/cli/commands_service_stop.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// Stop implements the logic from cmdStop.Run
|
||||
func (sc *ServiceCommand) Stop(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service stop command started")
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
|
||||
s, p, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
initInteractiveLogging()
|
||||
|
||||
status, err := s.Status()
|
||||
if errors.Is(err, service.ErrNotInstalled) {
|
||||
logger.Warn().Msg("Service not installed")
|
||||
return nil
|
||||
}
|
||||
if status == service.StatusStopped {
|
||||
logger.Warn().Msg("Service is already stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
logger.Error().Msg("Deactivation pin check failed")
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Stopping service")
|
||||
if doTasks([]task{{s.Stop, true, "Stop"}}) {
|
||||
logger.Notice().Msg("Service stopped")
|
||||
} else {
|
||||
logger.Error().Msg("Service stop failed")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service stop command completed")
|
||||
return nil
|
||||
}
|
||||
106
cmd/cli/commands_service_uninstall.go
Normal file
106
cmd/cli/commands_service_uninstall.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Uninstall implements the logic from cmdUninstall.Run
|
||||
func (sc *ServiceCommand) Uninstall(cmd *cobra.Command, args []string) error {
|
||||
logger := mainLog.Load()
|
||||
logger.Debug().Msg("Service uninstall command started")
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
|
||||
s, p, err := sc.initializeServiceManager()
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("Failed to initialize service manager")
|
||||
return err
|
||||
}
|
||||
|
||||
p.cfg = &cfg
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
if err := checkDeactivationPin(s, nil); isCheckDeactivationPinErr(err) {
|
||||
logger.Error().Msg("Deactivation pin check failed")
|
||||
os.Exit(deactivationPinInvalidExitCode)
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Starting service uninstall")
|
||||
uninstall(p, s)
|
||||
|
||||
if cleanup {
|
||||
logger.Debug().Msg("Performing cleanup operations")
|
||||
var files []string
|
||||
// Config file.
|
||||
files = append(files, v.ConfigFileUsed())
|
||||
// Log file and backup log file.
|
||||
// For safety, only process if log file path is absolute.
|
||||
if logFile := normalizeLogFilePath(cfg.Service.LogPath); filepath.IsAbs(logFile) {
|
||||
files = append(files, logFile)
|
||||
oldLogFile := logFile + oldLogSuffix
|
||||
if _, err := os.Stat(oldLogFile); err == nil {
|
||||
files = append(files, oldLogFile)
|
||||
}
|
||||
}
|
||||
// Socket files.
|
||||
if dir, _ := socketDir(); dir != "" {
|
||||
files = append(files, filepath.Join(dir, ctrldControlUnixSock))
|
||||
files = append(files, filepath.Join(dir, ctrldLogUnixSock))
|
||||
}
|
||||
// Static DNS settings files.
|
||||
withEachPhysicalInterfaces("", "", func(i *net.Interface) error {
|
||||
file := ctrld.SavedStaticDnsSettingsFilePath(i)
|
||||
files = append(files, file)
|
||||
return nil
|
||||
})
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to get executable path")
|
||||
}
|
||||
if bin != "" && supportedSelfDelete {
|
||||
files = append(files, bin)
|
||||
}
|
||||
// Backup file after upgrading.
|
||||
oldBin := bin + oldBinSuffix
|
||||
if _, err := os.Stat(oldBin); err == nil {
|
||||
files = append(files, oldBin)
|
||||
}
|
||||
for _, file := range files {
|
||||
if file == "" {
|
||||
continue
|
||||
}
|
||||
if err := os.Remove(file); err == nil {
|
||||
logger.Notice().Str("file", file).Msg("File removed during cleanup")
|
||||
} else {
|
||||
logger.Debug().Err(err).Str("file", file).Msg("Failed to remove file during cleanup")
|
||||
}
|
||||
}
|
||||
// Self-delete the ctrld binary if supported
|
||||
if err := selfDeleteExe(); err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to delete ctrld binary")
|
||||
} else {
|
||||
if !supportedSelfDelete {
|
||||
logger.Debug().Msgf("File removed: %s", bin)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Cleanup operations completed")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Service uninstall command completed")
|
||||
return nil
|
||||
}
|
||||
197
cmd/cli/commands_test.go
Normal file
197
cmd/cli/commands_test.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestBasicCommandStructure tests the actual root command structure
|
||||
func TestBasicCommandStructure(t *testing.T) {
|
||||
// Test the actual root command that's returned from initCLI()
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command has basic properties
|
||||
assert.Equal(t, "ctrld", rootCmd.Use)
|
||||
assert.NotEmpty(t, rootCmd.Short, "Root command should have a short description")
|
||||
|
||||
// Test that root command has subcommands
|
||||
commands := rootCmd.Commands()
|
||||
assert.NotNil(t, commands, "Root command should have subcommands")
|
||||
assert.Greater(t, len(commands), 0, "Root command should have at least one subcommand")
|
||||
|
||||
// Test that expected commands exist
|
||||
expectedCommands := []string{"run", "service", "clients", "upgrade", "log"}
|
||||
for _, cmdName := range expectedCommands {
|
||||
found := false
|
||||
for _, cmd := range commands {
|
||||
if cmd.Name() == cmdName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected command %s not found in root command", cmdName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestServiceCommandCreation tests service command creation
|
||||
func TestServiceCommandCreation(t *testing.T) {
|
||||
sc := NewServiceCommand()
|
||||
require.NotNil(t, sc, "ServiceCommand should be created")
|
||||
|
||||
// Test service config creation
|
||||
config := sc.createServiceConfig()
|
||||
require.NotNil(t, config, "Service config should be created")
|
||||
assert.Equal(t, ctrldServiceName, config.Name)
|
||||
assert.Equal(t, "Control-D Helper Service", config.DisplayName)
|
||||
assert.Equal(t, "A highly configurable, multi-protocol DNS forwarding proxy", config.Description)
|
||||
}
|
||||
|
||||
// TestServiceCommandSubCommands tests service command sub commands
|
||||
func TestServiceCommandSubCommands(t *testing.T) {
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "ctrld",
|
||||
Short: "DNS forwarding proxy",
|
||||
}
|
||||
|
||||
serviceCmd := InitServiceCmd(rootCmd)
|
||||
require.NotNil(t, serviceCmd, "Service command should be created")
|
||||
|
||||
// Test that service command has subcommands
|
||||
subcommands := serviceCmd.Commands()
|
||||
assert.Greater(t, len(subcommands), 0, "Service command should have subcommands")
|
||||
|
||||
// Test specific subcommands exist
|
||||
expectedCommands := []string{"start", "stop", "restart", "reload", "status", "uninstall", "interfaces"}
|
||||
|
||||
for _, cmdName := range expectedCommands {
|
||||
found := false
|
||||
for _, cmd := range subcommands {
|
||||
if cmd.Name() == cmdName {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected service subcommand %s not found", cmdName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCommandHelp tests basic help functionality
|
||||
func TestCommandHelp(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test help command execution
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
rootCmd.SetArgs([]string{"--help"})
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Help command should execute without error")
|
||||
assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description")
|
||||
}
|
||||
|
||||
// TestCommandVersion tests version command
|
||||
func TestCommandVersion(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
// Test version command
|
||||
rootCmd.SetArgs([]string{"--version"})
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Version command should execute without error")
|
||||
assert.Contains(t, buf.String(), "version", "Version output should contain version information")
|
||||
}
|
||||
|
||||
// TestCommandErrorHandling tests error handling
|
||||
func TestCommandErrorHandling(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test invalid flag instead of invalid command
|
||||
rootCmd.SetArgs([]string{"--invalid-flag"})
|
||||
err := rootCmd.Execute()
|
||||
assert.Error(t, err, "Invalid flag should return error")
|
||||
}
|
||||
|
||||
// TestCommandFlags tests flag functionality
|
||||
func TestCommandFlags(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command has expected flags
|
||||
verboseFlag := rootCmd.PersistentFlags().Lookup("verbose")
|
||||
assert.NotNil(t, verboseFlag, "Verbose flag should exist")
|
||||
assert.Equal(t, "v", verboseFlag.Shorthand)
|
||||
|
||||
silentFlag := rootCmd.PersistentFlags().Lookup("silent")
|
||||
assert.NotNil(t, silentFlag, "Silent flag should exist")
|
||||
assert.Equal(t, "s", silentFlag.Shorthand)
|
||||
}
|
||||
|
||||
// TestCommandExecution tests basic command execution
|
||||
func TestCommandExecution(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command can be executed (help command)
|
||||
var buf bytes.Buffer
|
||||
rootCmd.SetOut(&buf)
|
||||
rootCmd.SetErr(&buf)
|
||||
|
||||
rootCmd.SetArgs([]string{"--help"})
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Root command should execute without error")
|
||||
assert.Contains(t, buf.String(), "dns forwarding proxy", "Help output should contain description")
|
||||
}
|
||||
|
||||
// TestCommandArgs tests argument handling
|
||||
func TestCommandArgs(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command can handle arguments properly
|
||||
// Test with no args (should succeed)
|
||||
err := rootCmd.Execute()
|
||||
assert.NoError(t, err, "Root command with no args should execute")
|
||||
|
||||
// Test with help flag (should succeed)
|
||||
rootCmd.SetArgs([]string{"--help"})
|
||||
err = rootCmd.Execute()
|
||||
assert.NoError(t, err, "Root command with help flag should execute")
|
||||
}
|
||||
|
||||
// TestCommandSubcommands tests subcommand functionality
|
||||
func TestCommandSubcommands(t *testing.T) {
|
||||
// Initialize the CLI to set up the root command
|
||||
rootCmd := initCLI()
|
||||
|
||||
// Test that root command has subcommands
|
||||
commands := rootCmd.Commands()
|
||||
assert.Greater(t, len(commands), 0, "Root command should have subcommands")
|
||||
|
||||
// Test that specific subcommands exist and can be executed
|
||||
expectedSubcommands := []string{"run", "service", "clients", "upgrade", "log"}
|
||||
for _, subCmdName := range expectedSubcommands {
|
||||
// Find the subcommand
|
||||
var subCmd *cobra.Command
|
||||
for _, cmd := range commands {
|
||||
if cmd.Name() == subCmdName {
|
||||
subCmd = cmd
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotNil(t, subCmd, "Subcommand %s should exist", subCmdName)
|
||||
|
||||
// Test that subcommand has help
|
||||
assert.NotEmpty(t, subCmd.Short, "Subcommand %s should have a short description", subCmdName)
|
||||
}
|
||||
}
|
||||
192
cmd/cli/commands_upgrade.go
Normal file
192
cmd/cli/commands_upgrade.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/minio/selfupdate"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
upgradeChannelDev = "dev"
|
||||
upgradeChannelProd = "prod"
|
||||
upgradeChannelDefault = "default"
|
||||
)
|
||||
|
||||
// UpgradeCommand handles upgrade-related operations
|
||||
type UpgradeCommand struct {
|
||||
}
|
||||
|
||||
// NewUpgradeCommand creates a new upgrade command handler
|
||||
func NewUpgradeCommand() (*UpgradeCommand, error) {
|
||||
return &UpgradeCommand{}, nil
|
||||
}
|
||||
|
||||
// Upgrade performs the upgrade operation
|
||||
func (uc *UpgradeCommand) Upgrade(cmd *cobra.Command, args []string) error {
|
||||
upgradeChannel := map[string]string{
|
||||
upgradeChannelDefault: "https://dl.controld.dev",
|
||||
upgradeChannelDev: "https://dl.controld.dev",
|
||||
upgradeChannelProd: "https://dl.controld.com",
|
||||
}
|
||||
if isStableVersion(curVersion()) {
|
||||
upgradeChannel[upgradeChannelDefault] = upgradeChannel[upgradeChannelProd]
|
||||
}
|
||||
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to get current ctrld binary path")
|
||||
}
|
||||
|
||||
readConfig(false)
|
||||
v.Unmarshal(&cfg)
|
||||
svcCmd := NewServiceCommand()
|
||||
s, p, err := svcCmd.initializeServiceManager()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
p.preRun()
|
||||
if ir := runningIface(s); ir != nil {
|
||||
p.runningIface = ir.Name
|
||||
p.requiredMultiNICsConfig = ir.All
|
||||
}
|
||||
|
||||
svcInstalled := true
|
||||
if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) {
|
||||
svcInstalled = false
|
||||
}
|
||||
|
||||
oldBin := bin + oldBinSuffix
|
||||
baseUrl := upgradeChannel[upgradeChannelDefault]
|
||||
if len(args) > 0 {
|
||||
channel := args[0]
|
||||
switch channel {
|
||||
case upgradeChannelProd, upgradeChannelDev: // ok
|
||||
default:
|
||||
mainLog.Load().Fatal().Msgf("Upgrade argument must be either %q or %q", upgradeChannelProd, upgradeChannelDev)
|
||||
}
|
||||
baseUrl = upgradeChannel[channel]
|
||||
}
|
||||
|
||||
dlUrl := upgradeUrl(baseUrl)
|
||||
mainLog.Load().Debug().Msgf("Downloading binary: %s", dlUrl)
|
||||
|
||||
resp, err := getWithRetry(dlUrl, downloadServerIp)
|
||||
if err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to download binary")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
mainLog.Load().Fatal().Msgf("Could not download binary: %s", http.StatusText(resp.StatusCode))
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Msg("Updating current binary")
|
||||
if err := selfupdate.Apply(resp.Body, selfupdate.Options{OldSavePath: oldBin}); err != nil {
|
||||
if rerr := selfupdate.RollbackError(err); rerr != nil {
|
||||
mainLog.Load().Error().Err(rerr).Msg("Could not rollback old binary")
|
||||
}
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to update current binary")
|
||||
}
|
||||
|
||||
doRestart := func() bool {
|
||||
if !svcInstalled {
|
||||
return true
|
||||
}
|
||||
tasks := []task{
|
||||
{s.Stop, true, "Stop"},
|
||||
{func() error {
|
||||
// restore static DNS settings or DHCP
|
||||
p.resetDNS(false, true)
|
||||
return nil
|
||||
}, false, "Cleanup"},
|
||||
{func() error {
|
||||
time.Sleep(time.Second * 1)
|
||||
return nil
|
||||
}, false, "Waiting for service to stop"},
|
||||
}
|
||||
doTasks(tasks)
|
||||
|
||||
tasks = []task{
|
||||
{s.Start, true, "Start"},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
if dir, err := socketDir(); err == nil {
|
||||
if cc := newSocketControlClient(context.TODO(), s, dir); cc != nil {
|
||||
_, _ = cc.post(ifacePath, nil)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if svcInstalled {
|
||||
mainLog.Load().Debug().Msg("Restarting ctrld service using new binary")
|
||||
}
|
||||
|
||||
if doRestart() {
|
||||
_ = os.Remove(oldBin)
|
||||
_ = os.Chmod(bin, 0755)
|
||||
ver := "unknown version"
|
||||
out, err := exec.Command(bin, "--version").CombinedOutput()
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Failed to get new binary version")
|
||||
}
|
||||
if after, found := strings.CutPrefix(string(out), "ctrld version "); found {
|
||||
ver = after
|
||||
}
|
||||
mainLog.Load().Notice().Msgf("Upgrade successful - %s", ver)
|
||||
return nil
|
||||
}
|
||||
|
||||
mainLog.Load().Warn().Msgf("Upgrade failed, restoring previous binary: %s", oldBin)
|
||||
if err := os.Remove(bin); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to remove new binary")
|
||||
}
|
||||
if err := os.Rename(oldBin, bin); err != nil {
|
||||
mainLog.Load().Fatal().Err(err).Msg("Failed to restore old binary")
|
||||
}
|
||||
if doRestart() {
|
||||
mainLog.Load().Notice().Msg("Restored previous binary successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitUpgradeCmd creates the upgrade command with proper logic
|
||||
func InitUpgradeCmd(rootCmd *cobra.Command) *cobra.Command {
|
||||
upgradeCmd := &cobra.Command{
|
||||
Use: "upgrade",
|
||||
Short: "Upgrading ctrld to latest version",
|
||||
ValidArgs: []string{upgradeChannelDev, upgradeChannelProd},
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
uc, err := NewUpgradeCommand()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return uc.Upgrade(cmd, args)
|
||||
},
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(upgradeCmd)
|
||||
|
||||
return upgradeCmd
|
||||
}
|
||||
40
cmd/cli/control_client.go
Normal file
40
cmd/cli/control_client.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// controlClient represents an HTTP client for communicating with the control server
|
||||
type controlClient struct {
|
||||
c *http.Client
|
||||
}
|
||||
|
||||
// newControlClient creates a new control client with Unix socket transport
|
||||
func newControlClient(addr string) *controlClient {
|
||||
return &controlClient{c: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
d := net.Dialer{}
|
||||
return d.DialContext(ctx, "unix", addr)
|
||||
},
|
||||
},
|
||||
Timeout: time.Second * 30,
|
||||
}}
|
||||
}
|
||||
|
||||
func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) {
|
||||
// for log/send, set the timeout to 5 minutes
|
||||
if path == sendLogsPath {
|
||||
c.c.Timeout = time.Minute * 5
|
||||
}
|
||||
return c.c.Post("http://unix"+path, contentTypeJson, data)
|
||||
}
|
||||
|
||||
// deactivationRequest represents request for validating deactivation pin.
|
||||
type deactivationRequest struct {
|
||||
Pin int64 `json:"pin"`
|
||||
}
|
||||
349
cmd/cli/control_server.go
Normal file
349
cmd/cli/control_server.go
Normal file
@@ -0,0 +1,349 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
)
|
||||
|
||||
const (
|
||||
contentTypeJson = "application/json"
|
||||
listClientsPath = "/clients"
|
||||
startedPath = "/started"
|
||||
reloadPath = "/reload"
|
||||
deactivationPath = "/deactivation"
|
||||
cdPath = "/cd"
|
||||
ifacePath = "/iface"
|
||||
viewLogsPath = "/log/view"
|
||||
sendLogsPath = "/log/send"
|
||||
)
|
||||
|
||||
type ifaceResponse struct {
|
||||
Name string `json:"name"`
|
||||
All bool `json:"all"`
|
||||
OK bool `json:"ok"`
|
||||
}
|
||||
|
||||
// controlServer represents an HTTP server for handling control requests
|
||||
type controlServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
addr string
|
||||
}
|
||||
|
||||
// newControlServer creates a new control server instance
|
||||
func newControlServer(addr string) (*controlServer, error) {
|
||||
mux := http.NewServeMux()
|
||||
s := &controlServer{
|
||||
server: &http.Server{Handler: mux},
|
||||
mux: mux,
|
||||
}
|
||||
s.addr = addr
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *controlServer) start() error {
|
||||
_ = os.Remove(s.addr)
|
||||
unixListener, err := net.Listen("unix", s.addr)
|
||||
if l, ok := unixListener.(*net.UnixListener); ok {
|
||||
l.SetUnlinkOnClose(true)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.server.Serve(unixListener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *controlServer) stop() error {
|
||||
_ = os.Remove(s.addr)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
|
||||
defer cancel()
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
func (s *controlServer) register(pattern string, handler http.Handler) {
|
||||
s.mux.Handle(pattern, jsonResponse(handler))
|
||||
}
|
||||
|
||||
func (p *prog) registerControlServerHandler() {
|
||||
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
p.Debug().Msg("Handling list clients request")
|
||||
|
||||
clients := p.ciTable.ListClients()
|
||||
p.Debug().Int("client_count", len(clients)).Msg("Retrieved clients list")
|
||||
|
||||
sort.Slice(clients, func(i, j int) bool {
|
||||
return clients[i].IP.Less(clients[j].IP)
|
||||
})
|
||||
p.Debug().Msg("Sorted clients by IP address")
|
||||
|
||||
if p.metricsQueryStats.Load() {
|
||||
p.Debug().Msg("Metrics query stats enabled, collecting query counts")
|
||||
|
||||
for idx, client := range clients {
|
||||
p.Debug().
|
||||
Int("index", idx).
|
||||
Str("ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("Processing client metrics")
|
||||
|
||||
client.IncludeQueryCount = true
|
||||
dm := &dto.Metric{}
|
||||
|
||||
if statsClientQueriesCount.MetricVec == nil {
|
||||
p.Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("Skipping metrics collection: MetricVec is nil")
|
||||
continue
|
||||
}
|
||||
|
||||
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
|
||||
client.IP.String(),
|
||||
client.Mac,
|
||||
client.Hostname,
|
||||
)
|
||||
if err != nil {
|
||||
p.Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Str("mac", client.Mac).
|
||||
Str("hostname", client.Hostname).
|
||||
Msg("Failed to get metrics for client")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := m.Write(dm); err == nil && dm.Counter != nil {
|
||||
client.QueryCount = int64(dm.Counter.GetValue())
|
||||
p.Debug().
|
||||
Str("client_ip", client.IP.String()).
|
||||
Int64("query_count", client.QueryCount).
|
||||
Msg("Successfully collected query count")
|
||||
} else if err != nil {
|
||||
p.Debug().
|
||||
Err(err).
|
||||
Str("client_ip", client.IP.String()).
|
||||
Msg("Failed to write metric")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
p.Debug().Msg("Metrics query stats disabled, skipping query counts")
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&clients); err != nil {
|
||||
p.Error().
|
||||
Err(err).
|
||||
Int("client_count", len(clients)).
|
||||
Msg("Failed to encode clients response")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
p.Debug().
|
||||
Int("client_count", len(clients)).
|
||||
Msg("Successfully sent clients list response")
|
||||
}))
|
||||
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
select {
|
||||
case <-p.onStartedDone:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
case <-time.After(10 * time.Second):
|
||||
w.WriteHeader(http.StatusRequestTimeout)
|
||||
}
|
||||
}))
|
||||
p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
listeners := make(map[string]*ctrld.ListenerConfig)
|
||||
p.mu.Lock()
|
||||
for k, v := range p.cfg.Listener {
|
||||
listeners[k] = &ctrld.ListenerConfig{
|
||||
IP: v.IP,
|
||||
Port: v.Port,
|
||||
}
|
||||
}
|
||||
oldSvc := p.cfg.Service
|
||||
p.mu.Unlock()
|
||||
if err := p.sendReloadSignal(); err != nil {
|
||||
p.Error().Err(err).Msg("Could not send reload signal")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-p.reloadDoneCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
http.Error(w, "Timeout waiting for ctrld reload", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Checking for cases that we could not do a reload.
|
||||
|
||||
// 1. Listener config ip or port changes.
|
||||
for k, v := range p.cfg.Listener {
|
||||
l := listeners[k]
|
||||
if l == nil || l.IP != v.IP || l.Port != v.Port {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Service config changes.
|
||||
if !reflect.DeepEqual(oldSvc, p.cfg.Service) {
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, reload is done.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
p.cs.register(deactivationPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
// Non-cd mode always allowing deactivation.
|
||||
if cdUID == "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
// Re-fetch pin code from API.
|
||||
if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev); rc != nil {
|
||||
if rc.DeactivationPin != nil {
|
||||
cdDeactivationPin.Store(*rc.DeactivationPin)
|
||||
} else {
|
||||
cdDeactivationPin.Store(defaultDeactivationPin)
|
||||
}
|
||||
} else {
|
||||
p.Warn().Err(err).Msg("Could not re-fetch deactivation pin code")
|
||||
}
|
||||
|
||||
// If pin code not set, allowing deactivation.
|
||||
if !deactivationPinSet() {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
var req deactivationRequest
|
||||
if err := json.NewDecoder(request.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusPreconditionFailed)
|
||||
p.Error().Err(err).Msg("Invalid deactivation request")
|
||||
return
|
||||
}
|
||||
|
||||
code := http.StatusForbidden
|
||||
switch req.Pin {
|
||||
case cdDeactivationPin.Load():
|
||||
code = http.StatusOK
|
||||
select {
|
||||
case p.pinCodeValidCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
case defaultDeactivationPin:
|
||||
// If the pin code was set, but users do not provide --pin, return proper code to client.
|
||||
code = http.StatusBadRequest
|
||||
}
|
||||
w.WriteHeader(code)
|
||||
}))
|
||||
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if cdUID != "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(cdUID))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
res := &ifaceResponse{Name: iface}
|
||||
// p.setDNS is only called when running as a service
|
||||
if !service.Interactive() {
|
||||
<-p.csSetDnsDone
|
||||
if p.csSetDnsOk {
|
||||
res.Name = p.runningIface
|
||||
res.All = p.requiredMultiNICsConfig
|
||||
res.OK = true
|
||||
}
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(res); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
lr, err := p.logReaderRaw()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer lr.r.Close()
|
||||
if lr.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(lr.r)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}))
|
||||
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
||||
if time.Since(p.internalLogSent) < logWriterSentInterval {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
r, err := p.logReaderNoColor()
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.size == 0 {
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
req := &controld.LogsRequest{
|
||||
UID: cdUID,
|
||||
Data: r.r,
|
||||
}
|
||||
p.Debug().Msg("Sending log file to ControlD server")
|
||||
resp := logSentResponse{Size: r.size}
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil {
|
||||
p.Error().Msgf("Could not send log file to ControlD server: %v", err)
|
||||
resp.Error = err.Error()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
} else {
|
||||
p.Debug().Msg("Sending log file successfully")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
p.internalLogSent = time.Now()
|
||||
}))
|
||||
}
|
||||
|
||||
// jsonResponse wraps an HTTP handler to set JSON content type
|
||||
func jsonResponse(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
54
cmd/cli/control_server_test.go
Normal file
54
cmd/cli/control_server_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestControlServer(t *testing.T) {
|
||||
f, err := os.CreateTemp("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
f.Close()
|
||||
|
||||
s, err := newControlServer(f.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
pattern := "/ping"
|
||||
respBody := []byte("pong")
|
||||
s.register(pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write(respBody)
|
||||
}))
|
||||
if err := s.start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c := newControlClient(f.Name())
|
||||
resp, err := c.post(pattern, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("unepxected response code: %d", resp.StatusCode)
|
||||
}
|
||||
if ct := resp.Header.Get("content-type"); ct != contentTypeJson {
|
||||
t.Fatalf("unexpected content type: %s", ct)
|
||||
}
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(buf, respBody) {
|
||||
t.Errorf("unexpected response body, want: %q, got: %q", string(respBody), string(buf))
|
||||
}
|
||||
if err := s.stop(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
1921
cmd/cli/dns_proxy.go
Normal file
1921
cmd/cli/dns_proxy.go
Normal file
File diff suppressed because it is too large
Load Diff
803
cmd/cli/dns_proxy_test.go
Normal file
803
cmd/cli/dns_proxy_test.go
Normal file
@@ -0,0 +1,803 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
)
|
||||
|
||||
func Test_wildcardMatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wildcard string
|
||||
domain string
|
||||
match bool
|
||||
}{
|
||||
{"domain - prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
|
||||
{"domain - prefix", "*.windscribe.com", "anything.windscribe.com", true},
|
||||
{"domain - prefix not match other s", "*.windscribe.com", "example.com", false},
|
||||
{"domain - prefix not match s in name", "*.windscribe.com", "wwindscribe.com", false},
|
||||
{"domain - suffix", "suffix.*", "suffix.windscribe.com", true},
|
||||
{"domain - suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
||||
{"domain - both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
||||
{"domain - both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
||||
{"domain - case-insensitive", "*.WINDSCRIBE.com", "anything.windscribe.com", true},
|
||||
{"mac - prefix", "*:98:05:b4:2b", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - prefix not match other s", "*:98:05:b4:2b", "0d:ba:54:09:94:2c", false},
|
||||
{"mac - prefix not match s in name", "*:98:05:b4:2b", "e4:67:97:05:b4:2b", false},
|
||||
{"mac - suffix", "d4:67:98:*", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - suffix not match other", "d4:67:98:*", "d4:67:97:15:b4:2b", false},
|
||||
{"mac - both", "d4:67:98:*:b4:2b", "d4:67:98:05:b4:2b", true},
|
||||
{"mac - both not match", "d4:67:98:*:b4:2b", "d4:67:97:05:c4:2b", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := wildcardMatches(tc.wildcard, tc.domain); got != tc.match {
|
||||
t.Errorf("unexpected result, wildcard: %s, domain: %s, want: %v, got: %v", tc.wildcard, tc.domain, tc.match, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_canonicalName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
canonical string
|
||||
}{
|
||||
{"fqdn to canonical", "windscribe.com.", "windscribe.com"},
|
||||
{"already canonical", "windscribe.com", "windscribe.com"},
|
||||
{"case insensitive", "Windscribe.Com.", "windscribe.com"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := canonicalName(tc.domain); got != tc.canonical {
|
||||
t.Errorf("unexpected result, want: %s, got: %s", tc.canonical, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
cfg.Service.LeakOnUpstreamFailure = func(v bool) *bool { return &v }(false)
|
||||
p := &prog{cfg: cfg}
|
||||
p.logger.Store(mainLog.Load())
|
||||
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
||||
p.lanLoopGuard = newLoopGuard()
|
||||
p.ptrLoopGuard = newLoopGuard()
|
||||
for _, nc := range p.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
mac string
|
||||
defaultUpstreamNum string
|
||||
lc *ctrld.ListenerConfig
|
||||
domain string
|
||||
upstreams []string
|
||||
matched bool
|
||||
testLogMsg string
|
||||
}{
|
||||
{"Policy map matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "", "1", p.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "", "0", p.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
{"Policy Macs matches upper", "192.168.0.1:0", "14:45:A0:67:83:0A", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:45:a0:67:83:0a"},
|
||||
{"Policy Macs matches lower", "192.168.0.1:0", "14:54:4a:8e:08:2d", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
{"Policy Macs matches case-insensitive", "192.168.0.1:0", "14:54:4A:8E:08:2D", "0", p.cfg.Listener["0"], "abc.xyz", []string{"upstream.2"}, true, "14:54:4a:8e:08:2d"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, network := range []string{"udp", "tcp"} {
|
||||
var (
|
||||
addr net.Addr
|
||||
err error
|
||||
)
|
||||
switch network {
|
||||
case "udp":
|
||||
addr, err = net.ResolveUDPAddr(network, tc.ip)
|
||||
case "tcp":
|
||||
addr, err = net.ResolveTCPAddr(network, tc.ip)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
ufr := p.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.mac, tc.domain)
|
||||
p.proxy(ctx, &proxyRequest{
|
||||
msg: newDnsMsgWithHostname("foo", dns.TypeA),
|
||||
ufr: ufr,
|
||||
})
|
||||
assert.Equal(t, tc.matched, ufr.matched)
|
||||
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
||||
if tc.testLogMsg != "" {
|
||||
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_upstreamForWithCustomMatching(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
prog.logger.Store(mainLog.Load())
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a custom policy with domain-first matching order
|
||||
customPolicy := &ctrld.ListenerPolicyConfig{
|
||||
Name: "Custom Policy",
|
||||
Networks: []ctrld.Rule{
|
||||
{"network.0": []string{"upstream.1", "upstream.0"}},
|
||||
},
|
||||
Macs: []ctrld.Rule{
|
||||
{"14:45:A0:67:83:0A": []string{"upstream.2"}},
|
||||
},
|
||||
Rules: []ctrld.Rule{
|
||||
{"*.ru": []string{"upstream.1"}},
|
||||
},
|
||||
Matching: &ctrld.MatchingConfig{
|
||||
Order: []string{"domain", "mac", "network"},
|
||||
},
|
||||
}
|
||||
|
||||
customListener := &ctrld.ListenerConfig{
|
||||
Policy: customPolicy,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
mac string
|
||||
domain string
|
||||
upstreams []string
|
||||
matched bool
|
||||
}{
|
||||
{
|
||||
name: "Domain rule should match first with custom order",
|
||||
ip: "192.168.0.1:0",
|
||||
mac: "14:45:A0:67:83:0A",
|
||||
domain: "example.ru",
|
||||
upstreams: []string{"upstream.1"},
|
||||
matched: true,
|
||||
},
|
||||
{
|
||||
name: "MAC rule should match when no domain rule",
|
||||
ip: "192.168.0.1:0",
|
||||
mac: "14:45:A0:67:83:0A",
|
||||
domain: "example.com",
|
||||
upstreams: []string{"upstream.2"},
|
||||
matched: true,
|
||||
},
|
||||
{
|
||||
name: "Network rule should match when no domain or MAC rule",
|
||||
ip: "192.168.0.1:0",
|
||||
mac: "00:11:22:33:44:55",
|
||||
domain: "example.com",
|
||||
upstreams: []string{"upstream.1", "upstream.0"},
|
||||
matched: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
addr, err := net.ResolveUDPAddr("udp", tc.ip)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
ufr := prog.upstreamFor(ctx, "0", customListener, addr, tc.mac, tc.domain)
|
||||
|
||||
assert.Equal(t, tc.matched, ufr.matched)
|
||||
assert.Equal(t, tc.upstreams, ufr.upstreams)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
prog.logger.Store(mainLog.Load())
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
cacher, err := dnscache.NewLRUCache(4096)
|
||||
require.NoError(t, err)
|
||||
prog.cache = cacher
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com", dns.TypeA)
|
||||
msg.MsgHdr.RecursionDesired = true
|
||||
answer1 := new(dns.Msg)
|
||||
answer1.SetRcode(msg, dns.RcodeSuccess)
|
||||
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute)))
|
||||
answer2 := new(dns.Msg)
|
||||
answer2.SetRcode(msg, dns.RcodeRefused)
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
||||
|
||||
req1 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.1"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
req2 := &proxyRequest{
|
||||
msg: msg,
|
||||
ci: nil,
|
||||
failoverRcodes: nil,
|
||||
ufr: &upstreamForResult{
|
||||
upstreams: []string{"upstream.0"},
|
||||
matchedPolicy: "",
|
||||
matchedNetwork: "",
|
||||
matchedRule: "",
|
||||
matched: false,
|
||||
},
|
||||
}
|
||||
got1 := prog.proxy(context.Background(), req1)
|
||||
got2 := prog.proxy(context.Background(), req2)
|
||||
assert.NotSame(t, got1, got2)
|
||||
assert.Equal(t, answer1.Rcode, got1.answer.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.answer.Rcode)
|
||||
}
|
||||
|
||||
func Test_ipAndMacFromMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
wantIp bool
|
||||
mac string
|
||||
wantMac bool
|
||||
}{
|
||||
{"has ip v4 and mac", "1.2.3.4", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"has ip v6 and mac", "2606:1a40:3::1", true, "4c:20:b8:ab:87:1b", true},
|
||||
{"no ip", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
{"no mac", "1.2.3.4", false, "4c:20:b8:ab:87:1b", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ip := net.ParseIP(tc.ip)
|
||||
if ip == nil {
|
||||
t.Fatal("missing IP")
|
||||
}
|
||||
hw, err := net.ParseMAC(tc.mac)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
if tc.wantMac {
|
||||
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
||||
o.Option = append(o.Option, ec1)
|
||||
}
|
||||
if tc.wantIp {
|
||||
ec2 := &dns.EDNS0_SUBNET{Address: ip}
|
||||
o.Option = append(o.Option, ec2)
|
||||
}
|
||||
m.Extra = append(m.Extra, o)
|
||||
gotIP, gotMac := ipAndMacFromMsg(m)
|
||||
if tc.wantMac && gotMac != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, gotMac)
|
||||
}
|
||||
if !tc.wantMac && gotMac != "" {
|
||||
t.Errorf("unexpected mac: %q", gotMac)
|
||||
}
|
||||
if tc.wantIp && gotIP != tc.ip {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.ip, gotIP)
|
||||
}
|
||||
if !tc.wantIp && gotIP != "" {
|
||||
t.Errorf("unexpected ip: %q", gotIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_remoteAddrFromMsg(t *testing.T) {
|
||||
loopbackIP := net.ParseIP("127.0.0.1")
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
ci *ctrld.ClientInfo
|
||||
want string
|
||||
}{
|
||||
{"tcp", &net.TCPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.10"}, "192.168.1.10:12345"},
|
||||
{"udp", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.11"}, "192.168.1.11:12345"},
|
||||
{"nil client info", &net.UDPAddr{IP: loopbackIP, Port: 12345}, nil, "127.0.0.1:12345"},
|
||||
{"empty ip", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{}, "127.0.0.1:12345"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
addr := spoofRemoteAddr(tc.addr, tc.ci)
|
||||
if addr.String() != tc.want {
|
||||
t.Errorf("unexpected result, want: %q, got: %q", tc.want, addr.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ipFromARPA(t *testing.T) {
|
||||
tests := []struct {
|
||||
IP string
|
||||
ARPA string
|
||||
}{
|
||||
{"1.2.3.4", "4.3.2.1.in-addr.arpa."},
|
||||
{"245.110.36.114", "114.36.110.245.in-addr.arpa."},
|
||||
{"::ffff:12.34.56.78", "78.56.34.12.in-addr.arpa."},
|
||||
{"::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa."},
|
||||
{"1::", "0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.0.0.0.ip6.arpa."},
|
||||
{"1234:567::89a:bcde", "e.d.c.b.a.9.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"1234:567:fefe:bcbc:adad:9e4a:89a:bcde", "e.d.c.b.a.9.8.0.a.4.e.9.d.a.d.a.c.b.c.b.e.f.e.f.7.6.5.0.4.3.2.1.ip6.arpa."},
|
||||
{"", "asd.in-addr.arpa."},
|
||||
{"", "asd.ip6.arpa."},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.IP, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := ipFromARPA(tc.ARPA); !got.Equal(net.ParseIP(tc.IP)) {
|
||||
t.Errorf("unexpected ip, want: %s, got: %s", tc.IP, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithClientIP(ip string) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
o.Option = append(o.Option, &dns.EDNS0_SUBNET{Address: net.ParseIP(ip)})
|
||||
m.Extra = append(m.Extra, o)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_stripClientSubnet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
wantSubnet bool
|
||||
}{
|
||||
{"no edns0", new(dns.Msg), false},
|
||||
{"loopback IP v4", newDnsMsgWithClientIP("127.0.0.1"), false},
|
||||
{"loopback IP v6", newDnsMsgWithClientIP("::1"), false},
|
||||
{"private IP v4", newDnsMsgWithClientIP("192.168.1.123"), false},
|
||||
{"private IP v6", newDnsMsgWithClientIP("fd12:3456:789a:1::1"), false},
|
||||
{"public IP", newDnsMsgWithClientIP("1.1.1.1"), true},
|
||||
{"invalid IP", newDnsMsgWithClientIP(""), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
stripClientSubnet(tc.msg)
|
||||
hasSubnet := false
|
||||
if opt := tc.msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
if _, ok := s.(*dns.EDNS0_SUBNET); ok {
|
||||
hasSubnet = true
|
||||
}
|
||||
}
|
||||
}
|
||||
if tc.wantSubnet != hasSubnet {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.wantSubnet, hasSubnet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgWithHostname(hostname string, typ uint16) *dns.Msg {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(hostname, typ)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isLanHostnameQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isLanHostnameQuery bool
|
||||
}{
|
||||
{"A", newDnsMsgWithHostname("foo", dns.TypeA), true},
|
||||
{"AAAA", newDnsMsgWithHostname("foo", dns.TypeAAAA), true},
|
||||
{"A not LAN", newDnsMsgWithHostname("example.com", dns.TypeA), false},
|
||||
{"AAAA not LAN", newDnsMsgWithHostname("example.com", dns.TypeAAAA), false},
|
||||
{"Not A or AAAA", newDnsMsgWithHostname("foo", dns.TypeTXT), false},
|
||||
{".domain", newDnsMsgWithHostname("foo.domain", dns.TypeA), true},
|
||||
{".lan", newDnsMsgWithHostname("foo.lan", dns.TypeA), true},
|
||||
{".local", newDnsMsgWithHostname("foo.local", dns.TypeA), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isLanHostnameQuery(tc.msg); tc.isLanHostnameQuery != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isLanHostnameQuery, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newDnsMsgPtr(ip string, t *testing.T) *dns.Msg {
|
||||
t.Helper()
|
||||
m := new(dns.Msg)
|
||||
ptr, err := dns.ReverseAddr(ip)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m.SetQuestion(ptr, dns.TypePTR)
|
||||
return m
|
||||
}
|
||||
|
||||
func Test_isPrivatePtrLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isPrivatePtrLookup bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", newDnsMsgPtr("10.0.0.123", t), true},
|
||||
{"172.16.0.0/12", newDnsMsgPtr("172.16.0.123", t), true},
|
||||
{"192.168.0.0/16", newDnsMsgPtr("192.168.1.123", t), true},
|
||||
{"CGNAT", newDnsMsgPtr("100.66.27.28", t), true},
|
||||
{"Loopback", newDnsMsgPtr("127.0.0.1", t), true},
|
||||
{"Link Local Unicast", newDnsMsgPtr("fe80::69f6:e16e:8bdb:433f", t), true},
|
||||
{"Public IP", newDnsMsgPtr("8.8.8.8", t), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isPrivatePtrLookup(tc.msg); tc.isPrivatePtrLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isPrivatePtrLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isSrvLanLookup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msg *dns.Msg
|
||||
isSrvLookup bool
|
||||
}{
|
||||
{"SRV LAN", newDnsMsgWithHostname("foo", dns.TypeSRV), true},
|
||||
{"Not SRV", newDnsMsgWithHostname("foo", dns.TypeNone), false},
|
||||
{"Not SRV LAN", newDnsMsgWithHostname("controld.com", dns.TypeSRV), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isSrvLanLookup(tc.msg); tc.isSrvLookup != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isSrvLookup, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isWanClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
isWanClient bool
|
||||
}{
|
||||
// RFC 1918 allocates 10.0.0.0/8, 172.16.0.0/12, and 192.168.0.0/16 as
|
||||
{"10.0.0.0/8", &net.UDPAddr{IP: net.ParseIP("10.0.0.123")}, false},
|
||||
{"172.16.0.0/12", &net.UDPAddr{IP: net.ParseIP("172.16.0.123")}, false},
|
||||
{"192.168.0.0/16", &net.UDPAddr{IP: net.ParseIP("192.168.1.123")}, false},
|
||||
{"CGNAT", &net.UDPAddr{IP: net.ParseIP("100.66.27.28")}, false},
|
||||
{"Loopback", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}, false},
|
||||
{"Link Local Unicast", &net.UDPAddr{IP: net.ParseIP("fe80::69f6:e16e:8bdb:433f")}, false},
|
||||
{"Public", &net.UDPAddr{IP: net.ParseIP("8.8.8.8")}, true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isWanClient(tc.addr); tc.isWanClient != got {
|
||||
t.Errorf("unexpected result, want: %v, got: %v", tc.isWanClient, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_shouldStartRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
hasExistingRecovery bool
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "network change with existing recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
hasExistingRecovery: true,
|
||||
expectedResult: true,
|
||||
description: "should cancel existing recovery and start new one for network change",
|
||||
},
|
||||
{
|
||||
name: "network change without existing recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
hasExistingRecovery: false,
|
||||
expectedResult: true,
|
||||
description: "should start new recovery for network change",
|
||||
},
|
||||
{
|
||||
name: "regular failure with existing recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
hasExistingRecovery: true,
|
||||
expectedResult: false,
|
||||
description: "should skip duplicate recovery for regular failure",
|
||||
},
|
||||
{
|
||||
name: "regular failure without existing recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
hasExistingRecovery: false,
|
||||
expectedResult: true,
|
||||
description: "should start new recovery for regular failure",
|
||||
},
|
||||
{
|
||||
name: "OS failure with existing recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
hasExistingRecovery: true,
|
||||
expectedResult: false,
|
||||
description: "should skip duplicate recovery for OS failure",
|
||||
},
|
||||
{
|
||||
name: "OS failure without existing recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
hasExistingRecovery: false,
|
||||
expectedResult: true,
|
||||
description: "should start new recovery for OS failure",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
// Setup existing recovery if needed
|
||||
if tc.hasExistingRecovery {
|
||||
p.recoveryCancelMu.Lock()
|
||||
p.recoveryCancel = func() {} // Mock cancel function
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
result := p.shouldStartRecovery(tc.reason)
|
||||
assert.Equal(t, tc.expectedResult, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_createRecoveryContext(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
ctx, cleanup := p.createRecoveryContext()
|
||||
|
||||
// Verify context is created
|
||||
assert.NotNil(t, ctx)
|
||||
assert.NotNil(t, cleanup)
|
||||
|
||||
// Verify recoveryCancel is set
|
||||
p.recoveryCancelMu.Lock()
|
||||
assert.NotNil(t, p.recoveryCancel)
|
||||
p.recoveryCancelMu.Unlock()
|
||||
|
||||
// Test cleanup function
|
||||
cleanup()
|
||||
|
||||
// Verify recoveryCancel is cleared
|
||||
p.recoveryCancelMu.Lock()
|
||||
assert.Nil(t, p.recoveryCancel)
|
||||
p.recoveryCancelMu.Unlock()
|
||||
}
|
||||
|
||||
func Test_prepareForRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "regular failure",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "network change",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OS failure",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
err := p.prepareForRecovery(tc.reason)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify recoveryRunning is set to true
|
||||
assert.True(t, p.recoveryRunning.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_completeRecovery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
recovered string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "regular failure recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
recovered: "upstream1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "network change recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
recovered: "upstream2",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OS failure recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
recovered: "upstream3",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
err := p.completeRecovery(tc.reason, tc.recovered)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify recoveryRunning is set to false
|
||||
assert.False(t, p.recoveryRunning.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_reinitializeOSResolver(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
err := p.reinitializeOSResolver("Test message")
|
||||
|
||||
// This function should not return an error under normal circumstances
|
||||
// The actual behavior depends on the OS resolver implementation
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_handleRecovery_Integration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reason RecoveryReason
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "network change recovery",
|
||||
reason: RecoveryReasonNetworkChange,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "regular failure recovery",
|
||||
reason: RecoveryReasonRegularFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "OS failure recovery",
|
||||
reason: RecoveryReasonOSFailure,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := newTestProg(t)
|
||||
|
||||
// This is an integration test that exercises the full recovery flow
|
||||
// In a real test environment, you would mock the dependencies
|
||||
// For now, we're just testing that the method doesn't panic
|
||||
// and that the recovery logic flows correctly
|
||||
assert.NotPanics(t, func() {
|
||||
// Test only the preparation phase to avoid actual upstream checking
|
||||
if !p.shouldStartRecovery(tc.reason) {
|
||||
return
|
||||
}
|
||||
|
||||
_, cleanup := p.createRecoveryContext()
|
||||
defer cleanup()
|
||||
|
||||
if err := p.prepareForRecovery(tc.reason); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip the actual upstream recovery check for this test
|
||||
// as it requires properly configured upstreams
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// newTestProg creates a properly initialized *prog for testing.
|
||||
func newTestProg(t *testing.T) *prog {
|
||||
p := &prog{cfg: testhelper.SampleConfig(t)}
|
||||
p.logger.Store(mainLog.Load())
|
||||
p.um = newUpstreamMonitor(p.cfg, mainLog.Load())
|
||||
return p
|
||||
}
|
||||
18
cmd/cli/hostname.go
Normal file
18
cmd/cli/hostname.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import "regexp"
|
||||
|
||||
// validHostname reports whether hostname is a valid hostname.
|
||||
// A valid hostname contains 3 -> 64 characters and conform to RFC1123.
|
||||
// This function validates hostnames to ensure they meet DNS naming standards
|
||||
// and prevents invalid hostnames from being used in DNS configurations
|
||||
func validHostname(hostname string) bool {
|
||||
hostnameLen := len(hostname)
|
||||
if hostnameLen < 3 || hostnameLen > 64 {
|
||||
return false
|
||||
}
|
||||
// RFC1123 regex pattern ensures hostnames follow DNS naming conventions
|
||||
// This prevents issues with DNS resolution and system compatibility
|
||||
validHostnameRfc1123 := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
|
||||
return validHostnameRfc1123.MatchString(hostname)
|
||||
}
|
||||
35
cmd/cli/hostname_test.go
Normal file
35
cmd/cli/hostname_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_validHostname(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
hostname string
|
||||
valid bool
|
||||
}{
|
||||
{"localhost", "localhost", true},
|
||||
{"localdomain", "localhost.localdomain", true},
|
||||
{"localhost6", "localhost6.localdomain6", true},
|
||||
{"ip6", "ip6-localhost", true},
|
||||
{"non-domain", "controld", true},
|
||||
{"domain", "controld.com", true},
|
||||
{"empty", "", false},
|
||||
{"min length", "fo", false},
|
||||
{"max length", strings.Repeat("a", 65), false},
|
||||
{"special char", "foo!", false},
|
||||
{"non-ascii", "fooΩ", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.hostname, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
assert.True(t, validHostname(tc.hostname) == tc.valid)
|
||||
})
|
||||
}
|
||||
}
|
||||
172
cmd/cli/http_log.go
Normal file
172
cmd/cli/http_log.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HTTP log server endpoint constants
|
||||
const (
|
||||
httpLogEndpointPing = "/ping"
|
||||
httpLogEndpointLogs = "/logs"
|
||||
httpLogEndpointExit = "/exit"
|
||||
)
|
||||
|
||||
// httpLogClient sends logs to an HTTP server via POST requests.
|
||||
// This replaces the logConn functionality with HTTP-based communication.
|
||||
type httpLogClient struct {
|
||||
baseURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// newHTTPLogClient creates a new HTTP log client
|
||||
func newHTTPLogClient(sockPath string) *httpLogClient {
|
||||
return &httpLogClient{
|
||||
baseURL: "http://unix",
|
||||
client: &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Write sends log data to the HTTP server via POST request
|
||||
func (hlc *httpLogClient) Write(b []byte) (int, error) {
|
||||
// Send log data via HTTP POST to /logs endpoint
|
||||
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointLogs, "text/plain", bytes.NewReader(b))
|
||||
if err != nil {
|
||||
// Ignore errors to prevent log pollution, just like the original logConn
|
||||
return len(b), nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Ping tests if the HTTP log server is available
|
||||
func (hlc *httpLogClient) Ping() error {
|
||||
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointPing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close sends exit signal to the HTTP server
|
||||
func (hlc *httpLogClient) Close() error {
|
||||
// Send exit signal via HTTP POST with empty body
|
||||
resp, err := hlc.client.Post(hlc.baseURL+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLogs retrieves all collected logs from the HTTP server
|
||||
func (hlc *httpLogClient) GetLogs() ([]byte, error) {
|
||||
resp, err := hlc.client.Get(hlc.baseURL + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNoContent {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// httpLogServer starts an HTTP server listening on unix socket to collect logs from runCmd.
|
||||
func httpLogServer(sockPath string, stopLogCh chan struct{}) error {
|
||||
addr, err := net.ResolveUnixAddr("unix", sockPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid log sock path: %w", err)
|
||||
}
|
||||
|
||||
ln, err := net.ListenUnix("unix", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not listen log socket: %w", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
// Create a log writer to store all logs
|
||||
logWriter := newLogWriter()
|
||||
|
||||
// Use a sync.Once to ensure channel is only closed once
|
||||
var channelClosed sync.Once
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(httpLogEndpointPing, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mux.HandleFunc(httpLogEndpointLogs, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
// POST /logs - Store log data
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store log data in log writer
|
||||
logWriter.Write(body)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
case http.MethodGet:
|
||||
// GET /logs - Retrieve all logs
|
||||
// Get all logs from the log writer
|
||||
logWriter.mu.Lock()
|
||||
logs := logWriter.buf.Bytes()
|
||||
logWriter.mu.Unlock()
|
||||
|
||||
if len(logs) == 0 {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(logs)
|
||||
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc(httpLogEndpointExit, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Close the stop channel to signal completion (only once)
|
||||
channelClosed.Do(func() {
|
||||
close(stopLogCh)
|
||||
})
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
server := &http.Server{Handler: mux}
|
||||
return server.Serve(ln)
|
||||
}
|
||||
747
cmd/cli/http_log_test.go
Normal file
747
cmd/cli/http_log_test.go
Normal file
@@ -0,0 +1,747 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/nettest"
|
||||
)
|
||||
|
||||
func unixDomainSocketPath(t *testing.T) string {
|
||||
t.Helper()
|
||||
sockPath, err := nettest.LocalPath()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temporary directory: %v", err)
|
||||
}
|
||||
return sockPath
|
||||
}
|
||||
|
||||
func TestHTTPLogServer(t *testing.T) {
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Ping endpoint", func(t *testing.T) {
|
||||
resp, err := client.Get("http://unix" + httpLogEndpointPing)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to ping server: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Ping endpoint wrong method", func(t *testing.T) {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointPing, "text/plain", bytes.NewReader([]byte("test")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send POST to ping: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Log endpoint", func(t *testing.T) {
|
||||
testLog := "test log message"
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(testLog)))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send log: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if log was stored by retrieving it
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), testLog) {
|
||||
t.Errorf("Expected log '%s' not found in stored logs", testLog)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Log endpoint wrong method", func(t *testing.T) {
|
||||
// Test unsupported method (PUT) on /logs endpoint
|
||||
req, err := http.NewRequest("PUT", "http://unix"+httpLogEndpointLogs, bytes.NewReader([]byte("test")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create PUT request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send PUT to logs: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Exit endpoint", func(t *testing.T) {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send exit: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if channel is closed by trying to read from it
|
||||
select {
|
||||
case _, ok := <-stopLogCh:
|
||||
if ok {
|
||||
t.Error("Expected channel to be closed, but it's still open")
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Timeout waiting for channel closure")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Exit endpoint wrong method", func(t *testing.T) {
|
||||
resp, err := client.Get("http://unix" + httpLogEndpointExit)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send GET to exit: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusMethodNotAllowed {
|
||||
t.Errorf("Expected status 405, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Multiple log messages", func(t *testing.T) {
|
||||
logs := []string{"log1", "log2", "log3"}
|
||||
|
||||
for _, log := range logs {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send log '%s': %v", log, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Check if all logs were stored by retrieving them
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(body)
|
||||
for i, expectedLog := range logs {
|
||||
if !strings.Contains(logContent, expectedLog) {
|
||||
t.Errorf("Log %d: expected '%s' not found in stored logs", i, expectedLog)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Large log message", func(t *testing.T) {
|
||||
largeLog := strings.Repeat("a", 1024*10) // 10KB log message
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(largeLog)))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send large log: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if large log was stored by retrieving it
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(body), largeLog) {
|
||||
t.Error("Large log message was not stored correctly")
|
||||
}
|
||||
})
|
||||
|
||||
// Clean up
|
||||
os.Remove(sockPath)
|
||||
}
|
||||
|
||||
func TestHTTPLogServerInvalidSocketPath(t *testing.T) {
|
||||
// Test with invalid socket path
|
||||
invalidPath := "/invalid/path/that/does/not/exist.sock"
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
err := httpLogServer(invalidPath, stopLogCh)
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid socket path")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "could not listen log socket") {
|
||||
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPLogServerSocketInUse(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create the first server
|
||||
stopLogCh1 := make(chan struct{})
|
||||
serverErr1 := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr1 <- httpLogServer(sockPath, stopLogCh1)
|
||||
}()
|
||||
|
||||
// Wait for first server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Try to create a second server on the same socket
|
||||
stopLogCh2 := make(chan struct{})
|
||||
err := httpLogServer(sockPath, stopLogCh2)
|
||||
if err == nil {
|
||||
t.Error("Expected error when socket is already in use")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "could not listen log socket") {
|
||||
t.Errorf("Expected 'could not listen log socket' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPLogServerConcurrentRequests(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Send concurrent requests
|
||||
numRequests := 10
|
||||
done := make(chan bool, numRequests)
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func(i int) {
|
||||
defer func() { done <- true }()
|
||||
|
||||
logMsg := fmt.Sprintf("concurrent log %d", i)
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to send concurrent log %d: %v", i, err)
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for request %d, got %d", i, resp.StatusCode)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all requests to complete
|
||||
for i := 0; i < numRequests; i++ {
|
||||
select {
|
||||
case <-done:
|
||||
// Request completed
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Errorf("Timeout waiting for concurrent request %d", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if all logs were stored by retrieving them
|
||||
logsResp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer logsResp.Body.Close()
|
||||
|
||||
if logsResp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200 for logs, got %d", logsResp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(logsResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(body)
|
||||
// Verify all logs were stored
|
||||
for i := 0; i < numRequests; i++ {
|
||||
expectedLog := fmt.Sprintf("concurrent log %d", i)
|
||||
if !strings.Contains(logContent, expectedLog) {
|
||||
t.Errorf("Log '%s' was not stored", expectedLog)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPLogServerErrorHandling(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Invalid request body", func(t *testing.T) {
|
||||
// Test with malformed request - this will fail at HTTP level, not server level
|
||||
// The server will return 400 Bad Request for invalid body
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", strings.NewReader(""))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Empty body should still be processed successfully
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHTTPLogServer(b *testing.B) {
|
||||
// Create a temporary socket path
|
||||
tmpDir := b.TempDir()
|
||||
sockPath := filepath.Join(tmpDir, "bench.sock")
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
go func() {
|
||||
httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Benchmark log sending
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logMsg := fmt.Sprintf("benchmark log %d", i)
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(logMsg)))
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to send log: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Clean up
|
||||
os.Remove(sockPath)
|
||||
}
|
||||
|
||||
func TestHTTPLogClient(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP log client
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
t.Run("Ping server", func(t *testing.T) {
|
||||
err := client.Ping()
|
||||
if err != nil {
|
||||
t.Errorf("Ping failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write logs", func(t *testing.T) {
|
||||
testLog := "test log message from client"
|
||||
n, err := client.Write([]byte(testLog))
|
||||
if err != nil {
|
||||
t.Errorf("Write failed: %v", err)
|
||||
}
|
||||
if n != len(testLog) {
|
||||
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
|
||||
}
|
||||
|
||||
// Check if log was stored by retrieving it
|
||||
logs, err := client.GetLogs()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(string(logs), testLog) {
|
||||
t.Errorf("Expected log '%s' not found in stored logs", testLog)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Close client", func(t *testing.T) {
|
||||
err := client.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close failed: %v", err)
|
||||
}
|
||||
|
||||
// Check if channel is closed (signaling completion)
|
||||
select {
|
||||
case _, ok := <-stopLogCh:
|
||||
if ok {
|
||||
t.Error("Expected channel to be closed, but it's still open")
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Timeout waiting for channel closure")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPLogClientServerUnavailable(t *testing.T) {
|
||||
// Create client with non-existent socket
|
||||
sockPath := "/non/existent/socket.sock"
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
t.Run("Ping unavailable server", func(t *testing.T) {
|
||||
err := client.Ping()
|
||||
if err == nil {
|
||||
t.Error("Expected ping to fail for unavailable server")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Write to unavailable server", func(t *testing.T) {
|
||||
testLog := "test log message"
|
||||
n, err := client.Write([]byte(testLog))
|
||||
if err != nil {
|
||||
t.Errorf("Write should not return error (ignores errors): %v", err)
|
||||
}
|
||||
if n != len(testLog) {
|
||||
t.Errorf("Expected to write %d bytes, wrote %d", len(testLog), n)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Close unavailable server", func(t *testing.T) {
|
||||
err := client.Close()
|
||||
if err == nil {
|
||||
t.Error("Expected close to fail for unavailable server")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkHTTPLogClient(b *testing.B) {
|
||||
// Create a temporary socket path
|
||||
tmpDir := b.TempDir()
|
||||
sockPath := filepath.Join(tmpDir, "bench.sock")
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
go func() {
|
||||
httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP log client
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
// Benchmark client writes
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logMsg := fmt.Sprintf("benchmark write %d", i)
|
||||
client.Write([]byte(logMsg))
|
||||
}
|
||||
|
||||
// Clean up
|
||||
os.Remove(sockPath)
|
||||
}
|
||||
|
||||
func TestHTTPLogServerWithLogWriter(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
serverErr <- httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Store and retrieve logs", func(t *testing.T) {
|
||||
// Send multiple log messages
|
||||
logs := []string{"log message 1", "log message 2", "log message 3"}
|
||||
|
||||
for _, log := range logs {
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointLogs, "text/plain", bytes.NewReader([]byte(log+"\n")))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send log '%s': %v", log, err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
// Retrieve all logs
|
||||
resp, err := client.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read logs response: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(body)
|
||||
for _, log := range logs {
|
||||
if !strings.Contains(logContent, log) {
|
||||
t.Errorf("Expected log '%s' not found in retrieved logs", log)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Empty logs endpoint", func(t *testing.T) {
|
||||
// Create a new server for this test
|
||||
sockPath2 := unixDomainSocketPath(t)
|
||||
stopLogCh2 := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
httpLogServer(sockPath2, stopLogCh2)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client2 := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return net.Dial("unix", sockPath2)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client2.Get("http://unix" + httpLogEndpointLogs)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
t.Errorf("Expected status 204, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
os.Remove(sockPath2)
|
||||
})
|
||||
|
||||
t.Run("Channel closure on exit", func(t *testing.T) {
|
||||
// Send exit signal
|
||||
resp, err := client.Post("http://unix"+httpLogEndpointExit, "text/plain", bytes.NewReader([]byte{}))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send exit: %v", err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Check if channel is closed by trying to read from it
|
||||
select {
|
||||
case _, ok := <-stopLogCh:
|
||||
if ok {
|
||||
t.Error("Expected channel to be closed, but it's still open")
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Timeout waiting for channel closure")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHTTPLogClientGetLogs(t *testing.T) {
|
||||
// Create a temporary socket path
|
||||
sockPath := unixDomainSocketPath(t)
|
||||
defer os.Remove(sockPath)
|
||||
|
||||
// Create log channel
|
||||
stopLogCh := make(chan struct{})
|
||||
|
||||
// Start HTTP log server in a goroutine
|
||||
go func() {
|
||||
httpLogServer(sockPath, stopLogCh)
|
||||
}()
|
||||
|
||||
// Wait a bit for server to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Create HTTP log client
|
||||
client := newHTTPLogClient(sockPath)
|
||||
|
||||
t.Run("Get logs from client", func(t *testing.T) {
|
||||
// Send some logs
|
||||
testLogs := []string{"client log 1", "client log 2", "client log 3"}
|
||||
for _, log := range testLogs {
|
||||
client.Write([]byte(log + "\n"))
|
||||
}
|
||||
|
||||
// Retrieve logs using client method
|
||||
logs, err := client.GetLogs()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get logs: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(logs)
|
||||
for _, log := range testLogs {
|
||||
if !strings.Contains(logContent, log) {
|
||||
t.Errorf("Expected log '%s' not found in retrieved logs", log)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Get empty logs", func(t *testing.T) {
|
||||
// Create a new client for empty logs test
|
||||
sockPath2 := unixDomainSocketPath(t)
|
||||
stopLogCh2 := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
httpLogServer(sockPath2, stopLogCh2)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
client2 := newHTTPLogClient(sockPath2)
|
||||
logs, err := client2.GetLogs()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get empty logs: %v", err)
|
||||
}
|
||||
|
||||
if len(logs) != 0 {
|
||||
t.Errorf("Expected empty logs, got %d bytes", len(logs))
|
||||
}
|
||||
|
||||
os.Remove(sockPath2)
|
||||
})
|
||||
}
|
||||
111
cmd/cli/library.go
Normal file
111
cmd/cli/library.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AppCallback provides hooks for injecting certain functionalities
|
||||
// from mobile platforms to main ctrld cli.
|
||||
// This allows mobile applications to customize behavior without modifying core CLI code
|
||||
type AppCallback struct {
|
||||
HostName func() string
|
||||
LanIp func() string
|
||||
MacAddress func() string
|
||||
Exit func(error string)
|
||||
}
|
||||
|
||||
// AppConfig allows overwriting ctrld cli flags from mobile platforms.
|
||||
// This provides a clean interface for mobile apps to configure ctrld behavior
|
||||
type AppConfig struct {
|
||||
CdUID string
|
||||
ProvisionID string
|
||||
CustomHostname string
|
||||
HomeDir string
|
||||
UpstreamProto string
|
||||
Verbose int
|
||||
LogPath string
|
||||
}
|
||||
|
||||
// Network and HTTP configuration constants
|
||||
const (
|
||||
// defaultHTTPTimeout provides reasonable timeout for HTTP operations
|
||||
// This prevents hanging requests while allowing sufficient time for network delays
|
||||
defaultHTTPTimeout = 30 * time.Second
|
||||
|
||||
// defaultMaxRetries provides retry attempts for failed HTTP requests
|
||||
// This improves reliability in unstable network conditions
|
||||
defaultMaxRetries = 3
|
||||
|
||||
// downloadServerIp is the fallback IP for download operations
|
||||
// This ensures downloads work even when DNS resolution fails
|
||||
downloadServerIp = "23.171.240.151"
|
||||
)
|
||||
|
||||
// httpClientWithFallback returns an HTTP client configured with timeout and IPv4 fallback
|
||||
// This ensures reliable HTTP operations by preferring IPv4 and handling timeouts gracefully
|
||||
func httpClientWithFallback(timeout time.Duration) *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
// Prefer IPv4 over IPv6
|
||||
// This improves compatibility with networks that have IPv6 issues
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
FallbackDelay: 1 * time.Millisecond, // Very small delay to prefer IPv4
|
||||
}).DialContext,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// doWithRetry performs an HTTP request with retries
|
||||
// This improves reliability by automatically retrying failed requests with exponential backoff
|
||||
func doWithRetry(req *http.Request, maxRetries int, ip string) (*http.Response, error) {
|
||||
var lastErr error
|
||||
client := httpClientWithFallback(defaultHTTPTimeout)
|
||||
var ipReq *http.Request
|
||||
if ip != "" {
|
||||
ipReq = req.Clone(req.Context())
|
||||
ipReq.Host = ip
|
||||
ipReq.URL.Host = ip
|
||||
}
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Linear backoff reduces server load and improves success rate
|
||||
time.Sleep(time.Second * time.Duration(attempt+1))
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if ipReq != nil {
|
||||
mainLog.Load().Warn().Err(err).Msgf("Dial to %q failed", req.Host)
|
||||
mainLog.Load().Warn().Msgf("Fallback to direct ip to download prod version: %q", ip)
|
||||
resp, err = client.Do(ipReq)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
mainLog.Load().Debug().Err(err).
|
||||
Str("method", req.Method).
|
||||
Str("url", req.URL.String()).
|
||||
Msgf("HTTP request attempt %d/%d failed", attempt+1, maxRetries)
|
||||
}
|
||||
return nil, fmt.Errorf("failed after %d attempts to %s %s: %v", maxRetries, req.Method, req.URL, lastErr)
|
||||
}
|
||||
|
||||
// Helper for making GET requests with retries
|
||||
// This provides a simplified interface for common GET operations with built-in retry logic
|
||||
func getWithRetry(url string, ip string) (*http.Response, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doWithRetry(req, defaultMaxRetries, ip)
|
||||
}
|
||||
407
cmd/cli/log_writer.go
Normal file
407
cmd/cli/log_writer.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Log writer constants for buffer management and log formatting
|
||||
const (
|
||||
// logWriterSize is the default buffer size for log writers
|
||||
// This provides sufficient space for runtime logs without excessive memory usage
|
||||
logWriterSize = 1024 * 1024 * 5 // 5 MB
|
||||
|
||||
// logWriterSmallSize is used for memory-constrained environments
|
||||
// This reduces memory footprint while still maintaining log functionality
|
||||
logWriterSmallSize = 1024 * 1024 * 1 // 1 MB
|
||||
|
||||
// logWriterInitialSize is the initial buffer allocation
|
||||
// This provides immediate space for early log entries
|
||||
logWriterInitialSize = 32 * 1024 // 32 KB
|
||||
|
||||
// logWriterSentInterval controls how often logs are sent to external systems
|
||||
// This balances real-time logging with system performance
|
||||
logWriterSentInterval = time.Minute
|
||||
|
||||
// logWriterInitEndMarker marks the end of initialization logs
|
||||
// This helps separate startup logs from runtime logs
|
||||
logWriterInitEndMarker = "\n\n=== INIT_END ===\n\n"
|
||||
|
||||
// logWriterLogEndMarker marks the end of log sections
|
||||
// This provides clear boundaries for log parsing and analysis
|
||||
logWriterLogEndMarker = "\n\n=== LOG_END ===\n\n"
|
||||
)
|
||||
|
||||
// Custom level encoders that handle NOTICE level
|
||||
// Since NOTICE and WARN share the same numeric value (1), we handle them specially
|
||||
// in the encoder to display NOTICE messages with the "NOTICE" prefix.
|
||||
// Note: WARN messages will also display as "NOTICE" because they share the same level value.
|
||||
// This is the intended behavior for visual distinction.
|
||||
|
||||
// noticeLevelEncoder provides custom level encoding for NOTICE level
|
||||
// This ensures NOTICE messages are clearly distinguished from other log levels
|
||||
func noticeLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
|
||||
switch l {
|
||||
case ctrld.NoticeLevel:
|
||||
enc.AppendString("NOTICE")
|
||||
default:
|
||||
zapcore.CapitalLevelEncoder(l, enc)
|
||||
}
|
||||
}
|
||||
|
||||
// noticeColorLevelEncoder provides colored level encoding for NOTICE level
|
||||
// This uses cyan color to make NOTICE messages visually distinct in terminal output
|
||||
func noticeColorLevelEncoder(l zapcore.Level, enc zapcore.PrimitiveArrayEncoder) {
|
||||
switch l {
|
||||
case ctrld.NoticeLevel:
|
||||
enc.AppendString("\x1b[36mNOTICE\x1b[0m") // Cyan color for NOTICE
|
||||
default:
|
||||
zapcore.CapitalColorLevelEncoder(l, enc)
|
||||
}
|
||||
}
|
||||
|
||||
// logViewResponse represents the response structure for log viewing requests
|
||||
// This provides a consistent JSON format for log data retrieval
|
||||
type logViewResponse struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
// logSentResponse represents the response structure for log sending operations
|
||||
// This includes size information and error details for debugging
|
||||
type logSentResponse struct {
|
||||
Size int64 `json:"size"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// logReader provides read access to log data with size information.
|
||||
//
|
||||
// This struct encapsulates log reading functionality for external consumers,
|
||||
// providing both the log content and metadata about the log size. It supports
|
||||
// reading from both internal log buffers (when no external logging is configured)
|
||||
// and external log files (when logging to file is enabled).
|
||||
//
|
||||
// Fields:
|
||||
// - r: An io.ReadCloser that provides access to the log content
|
||||
// - size: The total size of the log data in bytes
|
||||
//
|
||||
// The logReader is used by the control server to serve log content to clients
|
||||
// and by various CLI commands that need to display or process log data.
|
||||
type logReader struct {
|
||||
r io.ReadCloser
|
||||
size int64
|
||||
}
|
||||
|
||||
// logWriter is an internal buffer to keep track of runtime log when no logging is enabled.
|
||||
// This provides in-memory log storage for debugging and monitoring purposes
|
||||
type logWriter struct {
|
||||
mu sync.Mutex
|
||||
buf bytes.Buffer
|
||||
size int
|
||||
}
|
||||
|
||||
// newLogWriter creates an internal log writer.
|
||||
// This provides the default log writer with standard buffer size
|
||||
func newLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSize)
|
||||
}
|
||||
|
||||
// newSmallLogWriter creates an internal log writer with small buffer size.
|
||||
// This is used in memory-constrained environments or for temporary logging
|
||||
func newSmallLogWriter() *logWriter {
|
||||
return newLogWriterWithSize(logWriterSmallSize)
|
||||
}
|
||||
|
||||
// newLogWriterWithSize creates an internal log writer with a given buffer size.
|
||||
// This allows customization of log buffer size based on specific requirements
|
||||
func newLogWriterWithSize(size int) *logWriter {
|
||||
lw := &logWriter{size: size}
|
||||
return lw
|
||||
}
|
||||
|
||||
// Write implements io.Writer interface for logWriter
|
||||
// This manages buffer overflow by discarding old data while preserving important markers
|
||||
func (lw *logWriter) Write(p []byte) (int, error) {
|
||||
lw.mu.Lock()
|
||||
defer lw.mu.Unlock()
|
||||
|
||||
// If writing p causes overflows, discard old data.
|
||||
// This prevents unbounded memory growth while maintaining recent logs
|
||||
if lw.buf.Len()+len(p) > lw.size {
|
||||
buf := lw.buf.Bytes()
|
||||
haveEndMarker := false
|
||||
// If there's init end marker already, preserve the data til the marker.
|
||||
// This ensures initialization logs are always available for debugging
|
||||
if idx := bytes.LastIndex(buf, []byte(logWriterInitEndMarker)); idx >= 0 {
|
||||
buf = buf[:idx+len(logWriterInitEndMarker)]
|
||||
haveEndMarker = true
|
||||
} else {
|
||||
// Otherwise, preserve the initial size data.
|
||||
buf = buf[:logWriterInitialSize]
|
||||
if idx := bytes.LastIndex(buf, []byte("\n")); idx != -1 {
|
||||
buf = buf[:idx]
|
||||
}
|
||||
}
|
||||
lw.buf.Reset()
|
||||
lw.buf.Write(buf)
|
||||
if !haveEndMarker {
|
||||
lw.buf.WriteString(logWriterInitEndMarker) // indicate that the log was truncated.
|
||||
}
|
||||
}
|
||||
// If p is bigger than buffer size, truncate p by half until its size is smaller.
|
||||
for len(p)+lw.buf.Len() > lw.size {
|
||||
p = p[len(p)/2:]
|
||||
}
|
||||
return lw.buf.Write(p)
|
||||
}
|
||||
|
||||
// initLogging initializes global logging setup.
|
||||
func (p *prog) initLogging(backup bool) {
|
||||
logCores := initLoggingWithBackup(backup)
|
||||
|
||||
// Initializing internal logging after global logging.
|
||||
p.initInternalLogging(logCores)
|
||||
p.logger.Store(mainLog.Load())
|
||||
}
|
||||
|
||||
// initInternalLogging performs internal logging if there's no log enabled.
|
||||
func (p *prog) initInternalLogging(externalCores []zapcore.Core) {
|
||||
if !p.needInternalLogging() {
|
||||
return
|
||||
}
|
||||
p.initInternalLogWriterOnce.Do(func() {
|
||||
p.Notice().Msg("Internal logging enabled")
|
||||
p.internalLogWriter = newLogWriter()
|
||||
p.internalLogSent = time.Now().Add(-logWriterSentInterval)
|
||||
p.internalWarnLogWriter = newSmallLogWriter()
|
||||
})
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
|
||||
// Create zap cores for different writers
|
||||
var cores []zapcore.Core
|
||||
cores = append(cores, externalCores...)
|
||||
|
||||
// Add core for internal log writer.
|
||||
// Run the internal logging at debug level, so we could
|
||||
// have enough information for troubleshooting.
|
||||
internalCore := newHumanReadableZapCore(lw, zapcore.DebugLevel)
|
||||
cores = append(cores, internalCore)
|
||||
|
||||
// Add core for internal warn log writer
|
||||
warnCore := newHumanReadableZapCore(wlw, zapcore.WarnLevel)
|
||||
cores = append(cores, warnCore)
|
||||
|
||||
// Create a multi-core logger
|
||||
multiCore := zapcore.NewTee(cores...)
|
||||
logger := zap.New(multiCore)
|
||||
|
||||
mainLog.Store(&ctrld.Logger{Logger: logger})
|
||||
}
|
||||
|
||||
// needInternalLogging reports whether prog needs to run internal logging.
|
||||
func (p *prog) needInternalLogging() bool {
|
||||
// Do not run in non-cd mode.
|
||||
if cdUID == "" {
|
||||
return false
|
||||
}
|
||||
// Do not run if there's already log file.
|
||||
if p.cfg.Service.LogPath != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// logReaderNoColor returns a logReader with ANSI color codes stripped from the log content.
|
||||
//
|
||||
// This method is useful when log content needs to be processed by tools that don't
|
||||
// handle ANSI escape sequences properly, or when storing logs in plain text format.
|
||||
// It internally calls logReader(true) to strip color codes.
|
||||
//
|
||||
// Returns:
|
||||
// - *logReader: A logReader instance with color codes removed, or nil if no logs available
|
||||
// - error: Any error encountered during log reading (e.g., empty logs, file access issues)
|
||||
//
|
||||
// Use cases:
|
||||
// - Log processing pipelines that require plain text
|
||||
// - Storing logs in databases or text files
|
||||
// - Displaying logs in environments that don't support color
|
||||
func (p *prog) logReaderNoColor() (*logReader, error) {
|
||||
return p.logReader(true)
|
||||
}
|
||||
|
||||
// logReaderRaw returns a logReader with ANSI color codes preserved in the log content.
|
||||
//
|
||||
// This method maintains the original formatting of log entries including color codes,
|
||||
// which is useful for displaying logs in terminals that support ANSI colors or when
|
||||
// the original visual formatting needs to be preserved. It internally calls logReader(false).
|
||||
//
|
||||
// Returns:
|
||||
// - *logReader: A logReader instance with color codes preserved, or nil if no logs available
|
||||
// - error: Any error encountered during log reading (e.g., empty logs, file access issues)
|
||||
//
|
||||
// Use cases:
|
||||
// - Terminal-based log viewers that support color
|
||||
// - Interactive debugging sessions
|
||||
// - Preserving original log formatting for display
|
||||
func (p *prog) logReaderRaw() (*logReader, error) {
|
||||
return p.logReader(false)
|
||||
}
|
||||
|
||||
// logReader creates a logReader instance for accessing log content with optional color stripping.
|
||||
//
|
||||
// This is the core method that handles log reading from different sources based on the
|
||||
// current logging configuration. It supports both internal logging (when no external
|
||||
// logging is configured) and external file logging (when logging to file is enabled).
|
||||
//
|
||||
// Behavior:
|
||||
// - Internal logging: Reads from internal log buffers (normal logs + warning logs)
|
||||
// and combines them with appropriate markers for separation
|
||||
// - External logging: Reads directly from the configured log file
|
||||
// - Empty logs: Returns appropriate error messages when no log content is available
|
||||
//
|
||||
// Parameters:
|
||||
// - stripColor: If true, removes ANSI color codes from log content; if false, preserves them
|
||||
//
|
||||
// Returns:
|
||||
// - *logReader: A logReader instance providing access to log content and size metadata
|
||||
// - error: Any error encountered during log reading, including:
|
||||
// - "nil internal log writer" - Internal logging not properly initialized
|
||||
// - "nil internal warn log writer" - Warning log writer not properly initialized
|
||||
// - "internal log is empty" - No content in internal log buffers
|
||||
// - "log file is empty" - External log file exists but contains no data
|
||||
// - File system errors when accessing external log files
|
||||
//
|
||||
// The method handles thread-safe access to internal log buffers and provides
|
||||
// comprehensive error handling for various edge cases.
|
||||
func (p *prog) logReader(stripColor bool) (*logReader, error) {
|
||||
if p.needInternalLogging() {
|
||||
p.mu.Lock()
|
||||
lw := p.internalLogWriter
|
||||
wlw := p.internalWarnLogWriter
|
||||
p.mu.Unlock()
|
||||
if lw == nil {
|
||||
return nil, errors.New("nil internal log writer")
|
||||
}
|
||||
if wlw == nil {
|
||||
return nil, errors.New("nil internal warn log writer")
|
||||
}
|
||||
|
||||
// Normal log content.
|
||||
lw.mu.Lock()
|
||||
lwReader := newLogReader(&lw.buf, stripColor)
|
||||
lwSize := lw.buf.Len()
|
||||
lw.mu.Unlock()
|
||||
// Warn log content.
|
||||
wlw.mu.Lock()
|
||||
wlwReader := newLogReader(&wlw.buf, stripColor)
|
||||
wlwSize := wlw.buf.Len()
|
||||
wlw.mu.Unlock()
|
||||
reader := io.MultiReader(lwReader, bytes.NewReader([]byte(logWriterLogEndMarker)), wlwReader)
|
||||
lr := &logReader{r: io.NopCloser(reader)}
|
||||
lr.size = int64(lwSize + wlwSize)
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("internal log is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
if p.cfg.Service.LogPath == "" {
|
||||
return &logReader{r: io.NopCloser(strings.NewReader(""))}, nil
|
||||
}
|
||||
f, err := os.Open(normalizeLogFilePath(p.cfg.Service.LogPath))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lr := &logReader{r: f}
|
||||
if st, err := f.Stat(); err == nil {
|
||||
lr.size = st.Size()
|
||||
} else {
|
||||
return nil, fmt.Errorf("f.Stat: %w", err)
|
||||
}
|
||||
if lr.size == 0 {
|
||||
return nil, errors.New("log file is empty")
|
||||
}
|
||||
return lr, nil
|
||||
}
|
||||
|
||||
// newHumanReadableZapCore creates a zap core optimized for human-readable log output.
|
||||
//
|
||||
// Features:
|
||||
// - Uses development encoder configuration for enhanced readability
|
||||
// - Console encoding with colored log levels for easy visual scanning
|
||||
// - Millisecond precision timestamps in human-friendly format
|
||||
// - Structured field output with clear key-value pairs
|
||||
// - Ideal for development, debugging, and interactive terminal sessions
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The output writer (e.g., os.Stdout, file, buffer)
|
||||
// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error)
|
||||
//
|
||||
// Returns a zapcore.Core configured for human consumption.
|
||||
func newHumanReadableZapCore(w io.Writer, level zapcore.Level) zapcore.Core {
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoderConfig.TimeKey = "time"
|
||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli)
|
||||
encoderConfig.EncodeLevel = noticeColorLevelEncoder
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
return zapcore.NewCore(encoder, zapcore.AddSync(w), level)
|
||||
}
|
||||
|
||||
// newMachineFriendlyZapCore creates a zap core optimized for machine processing and log aggregation.
|
||||
//
|
||||
// Features:
|
||||
// - Uses production encoder configuration for consistent, parseable output
|
||||
// - Console encoding with non-colored log levels for log parsing tools
|
||||
// - Millisecond precision timestamps in ISO-like format
|
||||
// - Structured field output optimized for log aggregation systems
|
||||
// - Ideal for production environments, log shipping, and automated analysis
|
||||
//
|
||||
// Parameters:
|
||||
// - w: The output writer (e.g., os.Stdout, file, buffer)
|
||||
// - level: Minimum log level to capture (e.g., Debug, Info, Warn, Error)
|
||||
//
|
||||
// Returns a zapcore.Core configured for machine consumption and log aggregation.
|
||||
func newMachineFriendlyZapCore(w io.Writer, level zapcore.Level) zapcore.Core {
|
||||
encoderConfig := zap.NewProductionEncoderConfig()
|
||||
encoderConfig.TimeKey = "time"
|
||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout(time.StampMilli)
|
||||
encoderConfig.EncodeLevel = noticeLevelEncoder
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
return zapcore.NewCore(encoder, zapcore.AddSync(w), level)
|
||||
}
|
||||
|
||||
// ansiRegex is a regular expression to match ANSI color codes.
|
||||
var ansiRegex = regexp.MustCompile(`\x1b\[[0-9;]*m`)
|
||||
|
||||
// newLogReader creates a reader for log buffer content with optional ANSI color stripping.
|
||||
//
|
||||
// This function provides flexible log content access by allowing consumers to choose
|
||||
// between raw log data (with ANSI color codes) or stripped content (without color codes).
|
||||
// The color stripping is useful when logs need to be processed by tools that don't
|
||||
// handle ANSI escape sequences properly, or when storing logs in plain text format.
|
||||
//
|
||||
// Parameters:
|
||||
// - buf: The log buffer containing the log data to read
|
||||
// - stripColor: If true, strips ANSI color codes from the log content;
|
||||
// if false, returns raw log content with color codes preserved
|
||||
//
|
||||
// Returns an io.Reader that provides access to the processed log content.
|
||||
func newLogReader(buf *bytes.Buffer, stripColor bool) io.Reader {
|
||||
if stripColor {
|
||||
return strings.NewReader(ansiRegex.ReplaceAllString(buf.String(), ""))
|
||||
}
|
||||
return strings.NewReader(buf.String())
|
||||
}
|
||||
417
cmd/cli/log_writer_test.go
Normal file
417
cmd/cli/log_writer_test.go
Normal file
@@ -0,0 +1,417 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_logWriter_Write(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
data := strings.Repeat("A", size)
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
newData := "B"
|
||||
halfData := strings.Repeat("A", len(data)/2) + logWriterInitEndMarker
|
||||
lw.Write([]byte(newData))
|
||||
if lw.buf.String() != halfData+newData {
|
||||
t.Fatalf("unexpected new buf content: %v", lw.buf.String())
|
||||
}
|
||||
|
||||
bigData := strings.Repeat("B", 256*1024)
|
||||
expected := halfData + strings.Repeat("B", 16*1024)
|
||||
lw.Write([]byte(bigData))
|
||||
if lw.buf.String() != expected {
|
||||
t.Fatalf("unexpected big buf content: %v", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_ConcurrentWrite(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
n := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
lw.Write([]byte(strings.Repeat("A", i)))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if lw.buf.Len() > lw.size {
|
||||
t.Fatalf("unexpected buf size: %v, content: %q", lw.buf.Len(), lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_logWriter_MarkerInitEnd(t *testing.T) {
|
||||
size := 64 * 1024
|
||||
lw := &logWriter{size: size}
|
||||
lw.buf.Grow(lw.size)
|
||||
|
||||
paddingSize := 10
|
||||
// Writing half of the size, minus len(end marker) and padding size.
|
||||
dataSize := size/2 - len(logWriterInitEndMarker) - paddingSize
|
||||
data := strings.Repeat("A", dataSize)
|
||||
// Inserting newline for making partial init data
|
||||
data += "\n"
|
||||
// Filling left over buffer to make the log full.
|
||||
// The data length: len(end marker) + padding size - 1 (for newline above) + size/2
|
||||
data += strings.Repeat("A", len(logWriterInitEndMarker)+paddingSize-1+(size/2))
|
||||
lw.Write([]byte(data))
|
||||
if lw.buf.String() != data {
|
||||
t.Fatalf("unexpected buf content: %v", lw.buf.String())
|
||||
}
|
||||
lw.Write([]byte("B"))
|
||||
lw.Write([]byte(strings.Repeat("B", 256*1024)))
|
||||
firstIdx := strings.Index(lw.buf.String(), logWriterInitEndMarker)
|
||||
lastIdx := strings.LastIndex(lw.buf.String(), logWriterInitEndMarker)
|
||||
// Check if init end marker present.
|
||||
if firstIdx == -1 || lastIdx == -1 {
|
||||
t.Fatalf("missing init end marker: %s", lw.buf.String())
|
||||
}
|
||||
// Check if init end marker appears only once.
|
||||
if firstIdx != lastIdx {
|
||||
t.Fatalf("log init end marker appears more than once: %s", lw.buf.String())
|
||||
}
|
||||
// Ensure that we have the correct init log data.
|
||||
if !strings.Contains(lw.buf.String(), strings.Repeat("A", dataSize)+logWriterInitEndMarker) {
|
||||
t.Fatalf("unexpected log content: %s", lw.buf.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestNoticeLevel tests that the custom NOTICE level works correctly
|
||||
func TestNoticeLevel(t *testing.T) {
|
||||
// Create a buffer to capture log output
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Create encoder config with custom NOTICE level support
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoderConfig.TimeKey = "time"
|
||||
encoderConfig.EncodeTime = zapcore.TimeEncoderOfLayout("15:04:05.000")
|
||||
encoderConfig.EncodeLevel = noticeLevelEncoder
|
||||
|
||||
// Test with NOTICE level
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
core := zapcore.NewCore(encoder, zapcore.AddSync(&buf), ctrld.NoticeLevel)
|
||||
logger := zap.New(core)
|
||||
ctrldLogger := &ctrld.Logger{Logger: logger}
|
||||
|
||||
// Log messages at different levels
|
||||
ctrldLogger.Debug().Msg("This is a DEBUG message")
|
||||
ctrldLogger.Info().Msg("This is an INFO message")
|
||||
ctrldLogger.Notice().Msg("This is a NOTICE message")
|
||||
ctrldLogger.Warn().Msg("This is a WARN message")
|
||||
ctrldLogger.Error().Msg("This is an ERROR message")
|
||||
|
||||
output := buf.String()
|
||||
|
||||
// Verify that DEBUG and INFO messages are NOT logged (filtered out)
|
||||
if strings.Contains(output, "DEBUG") {
|
||||
t.Error("DEBUG message should not be logged when level is NOTICE")
|
||||
}
|
||||
if strings.Contains(output, "INFO") {
|
||||
t.Error("INFO message should not be logged when level is NOTICE")
|
||||
}
|
||||
|
||||
// Verify that NOTICE, WARN, and ERROR messages ARE logged
|
||||
if !strings.Contains(output, "NOTICE") {
|
||||
t.Error("NOTICE message should be logged when level is NOTICE")
|
||||
}
|
||||
if !strings.Contains(output, "WARN") {
|
||||
t.Error("WARN message should be logged when level is NOTICE")
|
||||
}
|
||||
if !strings.Contains(output, "ERROR") {
|
||||
t.Error("ERROR message should be logged when level is NOTICE")
|
||||
}
|
||||
|
||||
// Verify the NOTICE message content
|
||||
if !strings.Contains(output, "This is a NOTICE message") {
|
||||
t.Error("NOTICE message content should be present")
|
||||
}
|
||||
|
||||
t.Logf("Log output with NOTICE level:\n%s", output)
|
||||
}
|
||||
|
||||
func TestNewLogReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
bufContent string
|
||||
stripColor bool
|
||||
expected string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "empty_buffer_no_color_strip",
|
||||
bufContent: "",
|
||||
stripColor: false,
|
||||
expected: "",
|
||||
description: "Empty buffer should return empty reader",
|
||||
},
|
||||
{
|
||||
name: "empty_buffer_with_color_strip",
|
||||
bufContent: "",
|
||||
stripColor: true,
|
||||
expected: "",
|
||||
description: "Empty buffer with color strip should return empty reader",
|
||||
},
|
||||
{
|
||||
name: "plain_text_no_color_strip",
|
||||
bufContent: "This is plain text without any color codes",
|
||||
stripColor: false,
|
||||
expected: "This is plain text without any color codes",
|
||||
description: "Plain text should be returned as-is when not stripping colors",
|
||||
},
|
||||
{
|
||||
name: "plain_text_with_color_strip",
|
||||
bufContent: "This is plain text without any color codes",
|
||||
stripColor: true,
|
||||
expected: "This is plain text without any color codes",
|
||||
description: "Plain text should be returned as-is when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "text_with_ansi_codes_no_strip",
|
||||
bufContent: "Normal text \x1b[31mred text\x1b[0m normal again",
|
||||
stripColor: false,
|
||||
expected: "Normal text \x1b[31mred text\x1b[0m normal again",
|
||||
description: "ANSI color codes should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "text_with_ansi_codes_with_strip",
|
||||
bufContent: "Normal text \x1b[31mred text\x1b[0m normal again",
|
||||
stripColor: true,
|
||||
expected: "Normal text red text normal again",
|
||||
description: "ANSI color codes should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "multiple_ansi_codes_no_strip",
|
||||
bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
|
||||
stripColor: false,
|
||||
expected: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
|
||||
description: "Multiple ANSI codes should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "multiple_ansi_codes_with_strip",
|
||||
bufContent: "\x1b[1mBold\x1b[0m \x1b[32mGreen\x1b[0m \x1b[34mBlue\x1b[0m text",
|
||||
stripColor: true,
|
||||
expected: "Bold Green Blue text",
|
||||
description: "Multiple ANSI codes should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "complex_ansi_sequences_no_strip",
|
||||
bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
|
||||
stripColor: false,
|
||||
expected: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
|
||||
description: "Complex ANSI sequences should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "complex_ansi_sequences_with_strip",
|
||||
bufContent: "\x1b[1;31;42mBold red on green\x1b[0m \x1b[38;5;208mOrange\x1b[0m",
|
||||
stripColor: true,
|
||||
expected: "Bold red on green Orange",
|
||||
description: "Complex ANSI sequences should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "ansi_codes_with_newlines_no_strip",
|
||||
bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
|
||||
stripColor: false,
|
||||
expected: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
|
||||
description: "ANSI codes with newlines should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "ansi_codes_with_newlines_with_strip",
|
||||
bufContent: "Line 1\n\x1b[31mRed line\x1b[0m\nLine 3",
|
||||
stripColor: true,
|
||||
expected: "Line 1\nRed line\nLine 3",
|
||||
description: "ANSI codes with newlines should be removed when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "malformed_ansi_codes_no_strip",
|
||||
bufContent: "Text \x1b[invalidm \x1b[0m normal",
|
||||
stripColor: false,
|
||||
expected: "Text \x1b[invalidm \x1b[0m normal",
|
||||
description: "Malformed ANSI codes should be preserved when not stripping",
|
||||
},
|
||||
{
|
||||
name: "malformed_ansi_codes_with_strip",
|
||||
bufContent: "Text \x1b[invalidm \x1b[0m normal",
|
||||
stripColor: true,
|
||||
expected: "Text \x1b[invalidm normal",
|
||||
description: "Non-matching ANSI sequences should be preserved when stripping colors",
|
||||
},
|
||||
{
|
||||
name: "large_buffer_no_strip",
|
||||
bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
|
||||
stripColor: false,
|
||||
expected: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
|
||||
description: "Large buffer should handle ANSI codes correctly when not stripping",
|
||||
},
|
||||
{
|
||||
name: "large_buffer_with_strip",
|
||||
bufContent: strings.Repeat("A", 10000) + "\x1b[31m" + strings.Repeat("B", 1000) + "\x1b[0m",
|
||||
stripColor: true,
|
||||
expected: strings.Repeat("A", 10000) + strings.Repeat("B", 1000),
|
||||
description: "Large buffer should remove ANSI codes correctly when stripping",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a buffer with the test content
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(tt.bufContent)
|
||||
|
||||
// Create the log reader
|
||||
reader := newLogReader(buf, tt.stripColor)
|
||||
|
||||
// Read all content from the reader
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read from log reader: %v", err)
|
||||
}
|
||||
|
||||
// Verify the content matches expected
|
||||
actual := string(content)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("Expected content: %q, got: %q", tt.expected, actual)
|
||||
t.Logf("Description: %s", tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogReader_ReaderBehavior(t *testing.T) {
|
||||
// Test that the returned reader behaves correctly
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("Test content with \x1b[31mred\x1b[0m text")
|
||||
|
||||
// Test with color stripping
|
||||
reader := newLogReader(buf, true)
|
||||
|
||||
// Test reading in chunks
|
||||
chunk1 := make([]byte, 10)
|
||||
n1, err := reader.Read(chunk1)
|
||||
if err != nil && err != io.EOF {
|
||||
t.Fatalf("Unexpected error reading first chunk: %v", err)
|
||||
}
|
||||
if n1 != 10 {
|
||||
t.Errorf("Expected to read 10 bytes, got %d", n1)
|
||||
}
|
||||
|
||||
// Test reading remaining content
|
||||
remaining, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read remaining content: %v", err)
|
||||
}
|
||||
|
||||
// Verify total content
|
||||
totalContent := string(chunk1[:n1]) + string(remaining)
|
||||
expected := "Test content with red text"
|
||||
if totalContent != expected {
|
||||
t.Errorf("Expected total content: %q, got: %q", expected, totalContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogReader_ConcurrentAccess(t *testing.T) {
|
||||
// Test concurrent access to the same buffer
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString("Concurrent test with \x1b[32mgreen\x1b[0m text")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
results := make(chan string, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reader := newLogReader(buf, true)
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read content: %v", err)
|
||||
return
|
||||
}
|
||||
results <- string(content)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
// Verify all goroutines got the same result
|
||||
expected := "Concurrent test with green text"
|
||||
for result := range results {
|
||||
if result != expected {
|
||||
t.Errorf("Expected: %q, got: %q", expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogReader_ANSIRegexEdgeCases(t *testing.T) {
|
||||
// Test edge cases for ANSI regex matching
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty_escape_sequence",
|
||||
input: "Text \x1b[m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "multiple_semicolons",
|
||||
input: "Text \x1b[1;2;3;4m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "numeric_only",
|
||||
input: "Text \x1b[123m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "mixed_numeric_semicolon",
|
||||
input: "Text \x1b[1;23;456m normal",
|
||||
expected: "Text normal",
|
||||
},
|
||||
{
|
||||
name: "no_closing_bracket",
|
||||
input: "Text \x1b[31 normal",
|
||||
expected: "Text \x1b[31 normal",
|
||||
},
|
||||
{
|
||||
name: "no_opening_bracket",
|
||||
input: "Text 31m normal",
|
||||
expected: "Text 31m normal",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
buf.WriteString(tt.input)
|
||||
|
||||
reader := newLogReader(buf, true)
|
||||
content, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read content: %v", err)
|
||||
}
|
||||
|
||||
actual := string(content)
|
||||
if actual != tt.expected {
|
||||
t.Errorf("Expected: %q, got: %q", tt.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
146
cmd/cli/loop.go
Normal file
146
cmd/cli/loop.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
loopTestDomain = ".test"
|
||||
loopTestQtype = dns.TypeTXT
|
||||
)
|
||||
|
||||
// newLoopGuard returns new loopGuard.
|
||||
func newLoopGuard() *loopGuard {
|
||||
return &loopGuard{inflight: make(map[string]struct{})}
|
||||
}
|
||||
|
||||
// loopGuard guards against DNS loop, ensuring only one query
|
||||
// for a given domain is processed at a time.
|
||||
type loopGuard struct {
|
||||
mu sync.Mutex
|
||||
inflight map[string]struct{}
|
||||
}
|
||||
|
||||
// TryLock marks the domain as being processed.
|
||||
func (lg *loopGuard) TryLock(domain string) bool {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
if _, inflight := lg.inflight[domain]; !inflight {
|
||||
lg.inflight[domain] = struct{}{}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Unlock marks the domain as being done.
|
||||
func (lg *loopGuard) Unlock(domain string) {
|
||||
lg.mu.Lock()
|
||||
defer lg.mu.Unlock()
|
||||
delete(lg.inflight, domain)
|
||||
}
|
||||
|
||||
// isLoop reports whether the given upstream config is detected as having DNS loop.
|
||||
func (p *prog) isLoop(uc *ctrld.UpstreamConfig) bool {
|
||||
p.loopMu.Lock()
|
||||
defer p.loopMu.Unlock()
|
||||
return p.loop[uc.UID()]
|
||||
}
|
||||
|
||||
// detectLoop checks if the given DNS message is initialized sent by ctrld.
|
||||
// If yes, marking the corresponding upstream as loop, prevent infinite DNS
|
||||
// forwarding loop.
|
||||
//
|
||||
// See p.checkDnsLoop for more details how it works.
|
||||
func (p *prog) detectLoop(msg *dns.Msg) {
|
||||
if len(msg.Question) != 1 {
|
||||
return
|
||||
}
|
||||
q := msg.Question[0]
|
||||
if q.Qtype != loopTestQtype {
|
||||
return
|
||||
}
|
||||
unFQDNname := strings.TrimSuffix(q.Name, ".")
|
||||
uid := strings.TrimSuffix(unFQDNname, loopTestDomain)
|
||||
p.loopMu.Lock()
|
||||
if _, loop := p.loop[uid]; loop {
|
||||
p.loop[uid] = loop
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
}
|
||||
|
||||
// checkDnsLoop sends a message to check if there's any DNS forwarding loop
|
||||
// with all the upstreams. The way it works based on dnsmasq --dns-loop-detect.
|
||||
//
|
||||
// - Generating a TXT test query and sending it to all upstream.
|
||||
// - The test query is formed by upstream UID and test domain: <uid>.test
|
||||
// - If the test query returns to ctrld, mark the corresponding upstream as loop (see p.detectLoop).
|
||||
//
|
||||
// See: https://thekelleys.org.uk/dnsmasq/docs/dnsmasq-man.html
|
||||
func (p *prog) checkDnsLoop() {
|
||||
p.Debug().Msg("Start checking DNS loop")
|
||||
upstream := make(map[string]*ctrld.UpstreamConfig)
|
||||
p.loopMu.Lock()
|
||||
for n, uc := range p.cfg.Upstream {
|
||||
if p.um.isDown("upstream." + n) {
|
||||
continue
|
||||
}
|
||||
// Do not send test query to external upstream.
|
||||
if !canBeLocalUpstream(uc.Domain) {
|
||||
p.Debug().Msgf("Skipping external: upstream.%s", n)
|
||||
continue
|
||||
}
|
||||
uid := uc.UID()
|
||||
p.loop[uid] = false
|
||||
upstream[uid] = uc
|
||||
}
|
||||
p.loopMu.Unlock()
|
||||
|
||||
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
||||
for uid := range p.loop {
|
||||
msg := loopTestMsg(uid)
|
||||
uc := upstream[uid]
|
||||
// Skipping upstream which is being marked as down.
|
||||
if uc == nil {
|
||||
continue
|
||||
}
|
||||
resolver, err := ctrld.NewResolver(loggerCtx, uc)
|
||||
if err != nil {
|
||||
p.Warn().Err(err).Msgf("Could not perform loop check for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
continue
|
||||
}
|
||||
if _, err := resolver.Resolve(context.Background(), msg); err != nil {
|
||||
p.Warn().Err(err).Msgf("Could not send DNS loop check query for upstream: %q, endpoint: %q", uc.Name, uc.Endpoint)
|
||||
}
|
||||
}
|
||||
p.Debug().Msg("End checking DNS loop")
|
||||
}
|
||||
|
||||
// checkDnsLoopTicker performs p.checkDnsLoop every minute.
|
||||
func (p *prog) checkDnsLoopTicker(ctx context.Context) {
|
||||
timer := time.NewTicker(time.Minute)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-timer.C:
|
||||
p.checkDnsLoop()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loopTestMsg creates a DNS test message for loop detection
|
||||
func loopTestMsg(uid string) *dns.Msg {
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(dns.Fqdn(uid+loopTestDomain), loopTestQtype)
|
||||
return msg
|
||||
}
|
||||
42
cmd/cli/loop_test.go
Normal file
42
cmd/cli/loop_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_loopGuard(t *testing.T) {
|
||||
lg := newLoopGuard()
|
||||
key := "foo"
|
||||
|
||||
var i atomic.Int64
|
||||
var started atomic.Int64
|
||||
n := 1000
|
||||
do := func() {
|
||||
locked := lg.TryLock(key)
|
||||
defer lg.Unlock(key)
|
||||
started.Add(1)
|
||||
for started.Load() < 2 {
|
||||
// Wait until at least 2 goroutines started, otherwise, on system with heavy load,
|
||||
// or having only 1 CPU, all goroutines can be scheduled to run consequently.
|
||||
}
|
||||
if locked {
|
||||
i.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
do()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if i.Load() == int64(n) {
|
||||
t.Fatalf("i must not be increased %d times", n)
|
||||
}
|
||||
}
|
||||
231
cmd/cli/main.go
Normal file
231
cmd/cli/main.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// Global variables for CLI configuration and state management
|
||||
// These are used across multiple commands and need to persist throughout the application lifecycle
|
||||
var (
|
||||
configPath string
|
||||
configBase64 string
|
||||
daemon bool
|
||||
listenAddress string
|
||||
primaryUpstream string
|
||||
secondaryUpstream string
|
||||
domains []string
|
||||
logPath string
|
||||
homedir string
|
||||
cacheSize int
|
||||
cfg ctrld.Config
|
||||
verbose int
|
||||
silent bool
|
||||
cdUID string
|
||||
cdOrg string
|
||||
customHostname string
|
||||
cdDev bool
|
||||
iface string
|
||||
ifaceStartStop string
|
||||
nextdns string
|
||||
cdUpstreamProto string
|
||||
deactivationPin int64
|
||||
skipSelfChecks bool
|
||||
cleanup bool
|
||||
startOnly bool
|
||||
rfc1918 bool
|
||||
|
||||
mainLog atomic.Pointer[ctrld.Logger]
|
||||
consoleWriter zapcore.Core
|
||||
consoleWriterLevel zapcore.Level
|
||||
noConfigStart bool
|
||||
)
|
||||
|
||||
// Flag name constants for consistent reference across the codebase
|
||||
// Using constants prevents typos and makes refactoring easier
|
||||
const (
|
||||
cdUidFlagName = "cd"
|
||||
cdOrgFlagName = "cd-org"
|
||||
customHostnameFlagName = "custom-hostname"
|
||||
nextdnsFlagName = "nextdns"
|
||||
)
|
||||
|
||||
// init initializes the default logger before any CLI commands are executed
|
||||
// This ensures logging is available even during early initialization phases
|
||||
func init() {
|
||||
l := zap.NewNop()
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
}
|
||||
|
||||
// Main is the entry point for the CLI application
|
||||
// It initializes configuration, sets up the CLI structure, and executes the root command
|
||||
func Main() {
|
||||
ctrld.InitConfig(v, "ctrld")
|
||||
rootCmd := initCLI()
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
mainLog.Load().Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeLogFilePath converts relative log file paths to absolute paths
|
||||
// This ensures log files are created in predictable locations regardless of working directory
|
||||
func normalizeLogFilePath(logFilePath string) string {
|
||||
if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() {
|
||||
return logFilePath
|
||||
}
|
||||
if homedir != "" {
|
||||
return filepath.Join(homedir, logFilePath)
|
||||
}
|
||||
dir, _ := userHomeDir()
|
||||
if dir == "" {
|
||||
return logFilePath
|
||||
}
|
||||
return filepath.Join(dir, logFilePath)
|
||||
}
|
||||
|
||||
// initConsoleLogging initializes console logging, then storing to mainLog.
|
||||
// This sets up human-readable logging output for interactive use
|
||||
func initConsoleLogging() {
|
||||
consoleWriterLevel = ctrld.NoticeLevel
|
||||
switch {
|
||||
case silent:
|
||||
// For silent mode, use a no-op logger to suppress all output
|
||||
l := zap.NewNop()
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
case verbose == 1:
|
||||
// Info level provides basic operational information
|
||||
consoleWriterLevel = zapcore.InfoLevel
|
||||
case verbose > 1:
|
||||
// Debug level provides detailed diagnostic information
|
||||
consoleWriterLevel = zapcore.DebugLevel
|
||||
}
|
||||
consoleWriter = newHumanReadableZapCore(os.Stdout, consoleWriterLevel)
|
||||
l := zap.New(consoleWriter)
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
}
|
||||
|
||||
// initInteractiveLogging is like initLogging, but the ProxyLogger is discarded
|
||||
// to be used for all interactive commands.
|
||||
//
|
||||
// Current log file config will also be ignored.
|
||||
// This prevents log file conflicts during interactive command execution
|
||||
func initInteractiveLogging() {
|
||||
old := cfg.Service.LogPath
|
||||
cfg.Service.LogPath = ""
|
||||
initLoggingWithBackup(false)
|
||||
cfg.Service.LogPath = old
|
||||
}
|
||||
|
||||
// initLoggingWithBackup initializes log setup base on current config.
|
||||
// If doBackup is true, backup old log file with ".1" suffix.
|
||||
//
|
||||
// This is only used in runCmd for special handling in case of logging config
|
||||
// change in cd mode. Without special reason, the caller should use initLogging
|
||||
// wrapper instead of calling this function directly.
|
||||
func initLoggingWithBackup(doBackup bool) []zapcore.Core {
|
||||
var writers []io.Writer
|
||||
if logFilePath := normalizeLogFilePath(cfg.Service.LogPath); logFilePath != "" {
|
||||
// Create parent directory if necessary.
|
||||
// This ensures log files can be created even if the directory doesn't exist
|
||||
if err := os.MkdirAll(filepath.Dir(logFilePath), 0750); err != nil {
|
||||
mainLog.Load().Error().Msgf("Failed to create log path: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Default open log file in append mode.
|
||||
// This preserves existing log entries across restarts
|
||||
flags := os.O_CREATE | os.O_RDWR | os.O_APPEND
|
||||
if doBackup {
|
||||
// Backup old log file with .1 suffix.
|
||||
// This prevents log file corruption during rotation
|
||||
if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) {
|
||||
mainLog.Load().Error().Msgf("Could not backup old log file: %v", err)
|
||||
} else {
|
||||
// Backup was created, set flags for truncating old log file.
|
||||
// This ensures a clean start for the new log file
|
||||
flags = os.O_CREATE | os.O_RDWR
|
||||
}
|
||||
}
|
||||
logFile, err := openLogFile(logFilePath, flags)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("Failed to create log file: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
writers = append(writers, logFile)
|
||||
}
|
||||
|
||||
// Create zap cores for different writers
|
||||
// Multiple cores allow logging to both console and file simultaneously
|
||||
var cores []zapcore.Core
|
||||
cores = append(cores, consoleWriter)
|
||||
|
||||
// Determine log level based on verbosity and configuration
|
||||
// This provides flexible logging control for different use cases
|
||||
logLevel := cfg.Service.LogLevel
|
||||
switch {
|
||||
case silent:
|
||||
// For silent mode, use a no-op logger to suppress all output
|
||||
l := zap.NewNop()
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
return cores
|
||||
case verbose == 1:
|
||||
logLevel = "info"
|
||||
case verbose > 1:
|
||||
logLevel = "debug"
|
||||
}
|
||||
|
||||
// Parse log level string to zapcore.Level
|
||||
// This provides human-readable log level configuration
|
||||
var level zapcore.Level
|
||||
switch logLevel {
|
||||
case "debug":
|
||||
level = zapcore.DebugLevel
|
||||
case "info":
|
||||
level = zapcore.InfoLevel
|
||||
case "notice":
|
||||
level = ctrld.NoticeLevel
|
||||
case "warn":
|
||||
level = zapcore.WarnLevel
|
||||
case "error":
|
||||
level = zapcore.ErrorLevel
|
||||
default:
|
||||
level = zapcore.InfoLevel // default level
|
||||
}
|
||||
|
||||
consoleWriter.Enabled(level)
|
||||
// Add cores for all writers
|
||||
// This enables multi-destination logging (console + file)
|
||||
for _, writer := range writers {
|
||||
core := newMachineFriendlyZapCore(writer, level)
|
||||
cores = append(cores, core)
|
||||
}
|
||||
|
||||
// Create a multi-core logger
|
||||
// This allows simultaneous logging to multiple destinations
|
||||
multiCore := zapcore.NewTee(cores...)
|
||||
logger := zap.New(multiCore)
|
||||
mainLog.Store(&ctrld.Logger{Logger: logger})
|
||||
|
||||
return cores
|
||||
}
|
||||
|
||||
// initCache initializes DNS cache configuration
|
||||
// This improves performance by caching frequently requested DNS responses
|
||||
func initCache() {
|
||||
if !cfg.Service.CacheEnable {
|
||||
return
|
||||
}
|
||||
if cfg.Service.CacheSize == 0 {
|
||||
// Default cache size provides good balance between memory usage and performance
|
||||
cfg.Service.CacheSize = 4096
|
||||
}
|
||||
}
|
||||
32
cmd/cli/main_test.go
Normal file
32
cmd/cli/main_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
var logOutput strings.Builder
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Create a custom writer that writes to logOutput
|
||||
writer := zapcore.AddSync(&logOutput)
|
||||
|
||||
// Create zap encoder
|
||||
encoderConfig := zap.NewDevelopmentEncoderConfig()
|
||||
encoder := zapcore.NewConsoleEncoder(encoderConfig)
|
||||
|
||||
// Create core that writes to our string builder
|
||||
core := zapcore.NewCore(encoder, writer, zap.DebugLevel)
|
||||
|
||||
// Create logger
|
||||
l := zap.New(core)
|
||||
|
||||
mainLog.Store(&ctrld.Logger{Logger: l})
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
157
cmd/cli/metrics.go
Normal file
157
cmd/cli/metrics.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/collectors"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"github.com/prometheus/prom2json"
|
||||
)
|
||||
|
||||
// metricsServer represents a server to expose Prometheus metrics via HTTP.
|
||||
// This provides monitoring and observability for the DNS proxy service
|
||||
type metricsServer struct {
|
||||
server *http.Server
|
||||
mux *http.ServeMux
|
||||
reg *prometheus.Registry
|
||||
addr string
|
||||
started bool
|
||||
}
|
||||
|
||||
// newMetricsServer returns new metrics server.
|
||||
// This initializes the HTTP server for exposing Prometheus metrics
|
||||
func newMetricsServer(addr string, reg *prometheus.Registry) (*metricsServer, error) {
|
||||
mux := http.NewServeMux()
|
||||
ms := &metricsServer{
|
||||
server: &http.Server{Handler: mux},
|
||||
mux: mux,
|
||||
reg: reg,
|
||||
}
|
||||
ms.addr = addr
|
||||
ms.registerMetricsServerHandler()
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// register adds handlers for given pattern.
|
||||
// This provides a clean interface for adding HTTP endpoints to the metrics server
|
||||
func (ms *metricsServer) register(pattern string, handler http.Handler) {
|
||||
ms.mux.Handle(pattern, handler)
|
||||
}
|
||||
|
||||
// registerMetricsServerHandler adds handlers for metrics server.
|
||||
// This sets up both Prometheus format and JSON format endpoints for metrics
|
||||
func (ms *metricsServer) registerMetricsServerHandler() {
|
||||
ms.register("/metrics", promhttp.HandlerFor(
|
||||
ms.reg,
|
||||
promhttp.HandlerOpts{
|
||||
EnableOpenMetrics: true,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
))
|
||||
ms.register("/metrics/json", jsonResponse(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
g := prometheus.ToTransactionalGatherer(ms.reg)
|
||||
mfs, done, err := g.Gather()
|
||||
defer done()
|
||||
if err != nil {
|
||||
msg := "could not gather metrics"
|
||||
mainLog.Load().Warn().Err(err).Msg(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
result := make([]*prom2json.Family, 0, len(mfs))
|
||||
for _, mf := range mfs {
|
||||
result = append(result, prom2json.NewFamily(mf))
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(result); err != nil {
|
||||
msg := "could not marshal metrics result"
|
||||
mainLog.Load().Warn().Err(err).Msg(msg)
|
||||
http.Error(w, msg, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
})))
|
||||
}
|
||||
|
||||
// start runs the metricsServer.
|
||||
// This starts the HTTP server for metrics exposure
|
||||
func (ms *metricsServer) start() error {
|
||||
listener, err := net.Listen("tcp", ms.addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go ms.server.Serve(listener)
|
||||
ms.started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// stop shutdowns the metricsServer within 2 seconds timeout.
|
||||
// This ensures graceful shutdown of the metrics server
|
||||
func (ms *metricsServer) stop() error {
|
||||
if !ms.started {
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*1)
|
||||
defer cancel()
|
||||
return ms.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// runMetricsServer initializes metrics stats and runs the metrics server if enabled.
|
||||
// This sets up the complete metrics infrastructure including Prometheus collectors
|
||||
func (p *prog) runMetricsServer(ctx context.Context, reloadCh chan struct{}) {
|
||||
if !p.metricsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
// Reset all stats.
|
||||
statsVersion.Reset()
|
||||
statsQueriesCount.Reset()
|
||||
statsClientQueriesCount.Reset()
|
||||
|
||||
reg := prometheus.NewRegistry()
|
||||
// Register queries count stats if enabled.
|
||||
if p.metricsQueryStats.Load() {
|
||||
reg.MustRegister(statsQueriesCount)
|
||||
reg.MustRegister(statsClientQueriesCount)
|
||||
}
|
||||
|
||||
addr := p.cfg.Service.MetricsListener
|
||||
ms, err := newMetricsServer(addr, reg)
|
||||
if err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Could not create new metrics server")
|
||||
return
|
||||
}
|
||||
// Only start listener address if defined.
|
||||
if addr != "" {
|
||||
// Go runtime stats.
|
||||
reg.MustRegister(collectors.NewBuildInfoCollector())
|
||||
reg.MustRegister(collectors.NewGoCollector(
|
||||
collectors.WithGoCollectorRuntimeMetrics(collectors.MetricsAll),
|
||||
))
|
||||
// ctrld stats.
|
||||
reg.MustRegister(statsVersion)
|
||||
statsVersion.WithLabelValues(commit, runtime.Version(), curVersion()).Inc()
|
||||
reg.MustRegister(statsTimeStart)
|
||||
statsTimeStart.Set(float64(time.Now().Unix()))
|
||||
mainLog.Load().Debug().Msgf("Starting metrics server on: %s", addr)
|
||||
if err := ms.start(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Could not start metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
case <-ctx.Done():
|
||||
case <-reloadCh:
|
||||
}
|
||||
|
||||
if err := ms.stop(); err != nil {
|
||||
mainLog.Load().Warn().Err(err).Msg("Could not stop metrics server")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -9,17 +9,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) error {
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
b, err := exec.Command("networksetup", "-listnetworkserviceorder").Output()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
patched := false
|
||||
if name := networkServiceName(iface.Name, bytes.NewReader(b)); name != "" {
|
||||
patched = true
|
||||
iface.Name = name
|
||||
mainLog.Debug().Str("network_service", name).Msg("found network service name for interface")
|
||||
}
|
||||
return nil
|
||||
return patched, nil
|
||||
}
|
||||
|
||||
func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
@@ -42,3 +43,9 @@ func networkServiceName(ifaceName string, r io.Reader) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
17
cmd/cli/net_linux.go
Normal file
17
cmd/cli/net_linux.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// patchNetIfaceName patches network interface names on Linux
|
||||
// This is a no-op on Linux as interface names don't need special handling
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// Only non-virtual interfaces are considered valid.
|
||||
// This prevents DNS configuration on virtual interfaces like docker, veth, etc.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
13
cmd/cli/net_others.go
Normal file
13
cmd/cli/net_others.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !darwin && !windows && !linux
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// patchNetIfaceName patches network interface names on non-Linux/Darwin platforms
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) { return true, nil }
|
||||
|
||||
// validInterface checks if an interface is valid on non-Linux/Darwin platforms
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool { return true }
|
||||
16
cmd/cli/net_windows.go
Normal file
16
cmd/cli/net_windows.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func patchNetIfaceName(iface *net.Interface) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// validInterface reports whether the *net.Interface is a valid one.
|
||||
// On Windows, only physical interfaces are considered valid.
|
||||
func validInterface(iface *net.Interface, validIfacesMap map[string]struct{}) bool {
|
||||
_, ok := validIfacesMap[iface.Name]
|
||||
return ok
|
||||
}
|
||||
47
cmd/cli/net_windows_test.go
Normal file
47
cmd/cli/net_windows_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_validInterfaces(t *testing.T) {
|
||||
verbose = 3
|
||||
initConsoleLogging()
|
||||
start := time.Now()
|
||||
im := ctrld.ValidInterfaces(ctrld.LoggerCtx(context.Background(), mainLog.Load()))
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
ifaces := slices.Collect(maps.Keys(im))
|
||||
|
||||
start = time.Now()
|
||||
ifacesPowershell := validInterfacesPowershell()
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(ifaces)
|
||||
slices.Sort(ifacesPowershell)
|
||||
if !slices.Equal(ifaces, ifacesPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", ifacesPowershell, ifaces)
|
||||
}
|
||||
}
|
||||
|
||||
func validInterfacesPowershell() []string {
|
||||
out, err := powershell("Get-NetAdapter -Physical | Select-Object -ExpandProperty Name")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var res []string
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
ifaceName := strings.TrimSpace(scanner.Text())
|
||||
res = append(res, ifaceName)
|
||||
}
|
||||
return res
|
||||
}
|
||||
36
cmd/cli/netlink_linux.go
Normal file
36
cmd/cli/netlink_linux.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func (p *prog) watchLinkState(ctx context.Context) {
|
||||
ch := make(chan netlink.LinkUpdate)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
if err := netlink.LinkSubscribe(ch, done); err != nil {
|
||||
p.Warn().Err(err).Msg("Could not subscribe link")
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case lu := <-ch:
|
||||
if lu.Change == 0xFFFFFFFF {
|
||||
continue
|
||||
}
|
||||
if lu.Change&unix.IFF_UP != 0 {
|
||||
p.Debug().Msgf("Link state changed, re-bootstrapping")
|
||||
for _, uc := range p.cfg.Upstream {
|
||||
uc.ReBootstrap(ctrld.LoggerCtx(ctx, p.logger.Load()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
7
cmd/cli/netlink_others.go
Normal file
7
cmd/cli/netlink_others.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !linux
|
||||
|
||||
package cli
|
||||
|
||||
import "context"
|
||||
|
||||
func (p *prog) watchLinkState(ctx context.Context) {}
|
||||
90
cmd/cli/network_manager_linux.go
Normal file
90
cmd/cli/network_manager_linux.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/dbus"
|
||||
)
|
||||
|
||||
const (
|
||||
nmConfDir = "/etc/NetworkManager/conf.d"
|
||||
nmCtrldConfFilename = "99-ctrld.conf"
|
||||
nmCtrldConfContent = `[main]
|
||||
dns=none
|
||||
systemd-resolved=false
|
||||
`
|
||||
nmSystemdUnitName = "NetworkManager.service"
|
||||
)
|
||||
|
||||
var networkManagerCtrldConfFile = filepath.Join(nmConfDir, nmCtrldConfFilename)
|
||||
|
||||
// hasNetworkManager reports whether NetworkManager executable found.
|
||||
// hasNetworkManager checks if NetworkManager is available on the system
|
||||
func hasNetworkManager() bool {
|
||||
exe, _ := exec.LookPath("NetworkManager")
|
||||
return exe != ""
|
||||
}
|
||||
|
||||
func (p *prog) setupNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
if content, _ := os.ReadFile(nmCtrldConfContent); string(content) == nmCtrldConfContent {
|
||||
p.Debug().Msg("NetworkManager already setup, nothing to do")
|
||||
return nil
|
||||
}
|
||||
err := os.WriteFile(networkManagerCtrldConfFile, []byte(nmCtrldConfContent), os.FileMode(0644))
|
||||
if os.IsNotExist(err) {
|
||||
p.Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
p.Debug().Err(err).Msg("Could not write NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
p.reloadNetworkManager()
|
||||
p.Debug().Msg("Setup NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) restoreNetworkManager() error {
|
||||
if !hasNetworkManager() {
|
||||
return nil
|
||||
}
|
||||
err := os.Remove(networkManagerCtrldConfFile)
|
||||
if os.IsNotExist(err) {
|
||||
p.Debug().Msg("NetworkManager is not available")
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
p.Debug().Err(err).Msg("Could not remove NetworkManager ctrld config file")
|
||||
return err
|
||||
}
|
||||
|
||||
p.reloadNetworkManager()
|
||||
p.Debug().Msg("Restore NetworkManager done")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) reloadNetworkManager() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
conn, err := dbus.NewSystemConnectionContext(ctx)
|
||||
if err != nil {
|
||||
p.Error().Err(err).Msg("Could not create new system connection")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
waitCh := make(chan string)
|
||||
if _, err := conn.ReloadUnitContext(ctx, nmSystemdUnitName, "ignore-dependencies", waitCh); err != nil {
|
||||
p.Debug().Err(err).Msg("Could not reload NetworkManager")
|
||||
return
|
||||
}
|
||||
<-waitCh
|
||||
}
|
||||
15
cmd/cli/network_manager_others.go
Normal file
15
cmd/cli/network_manager_others.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !linux
|
||||
|
||||
package cli
|
||||
|
||||
func (p *prog) setupNetworkManager() error {
|
||||
p.reloadNetworkManager()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) restoreNetworkManager() error {
|
||||
p.reloadNetworkManager()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *prog) reloadNetworkManager() {}
|
||||
32
cmd/cli/nextdns.go
Normal file
32
cmd/cli/nextdns.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const nextdnsURL = "https://dns.nextdns.io"
|
||||
|
||||
// generateNextDNSConfig generates NextDNS configuration for the given UID
|
||||
func generateNextDNSConfig(uid string) {
|
||||
if uid == "" {
|
||||
return
|
||||
}
|
||||
mainLog.Load().Info().Msg("Generating ctrld config for NextDNS resolver")
|
||||
cfg = ctrld.Config{
|
||||
Listener: map[string]*ctrld.ListenerConfig{
|
||||
"0": {
|
||||
IP: "0.0.0.0",
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Upstream: map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
Type: ctrld.ResolverTypeDOH3,
|
||||
Endpoint: fmt.Sprintf("%s/%s", nextdnsURL, uid),
|
||||
Timeout: 5000,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
5
cmd/cli/nocgo.go
Normal file
5
cmd/cli/nocgo.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !cgo
|
||||
|
||||
package cli
|
||||
|
||||
const cgoEnabled = false
|
||||
128
cmd/cli/os_darwin.go
Normal file
128
cmd/cli/os_darwin.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// allocateIP allocates an IP address on the specified interface
|
||||
// sudo ifconfig lo0 alias 127.0.0.2 up
|
||||
func allocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", "alias", ip, "up")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("AllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// deAllocateIP deallocates an IP address from the specified interface
|
||||
func deAllocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", "-alias", ip)
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
if err := setDNS(iface, nameservers); err != nil {
|
||||
// TODO: investiate whether we can detect this without relying on error message.
|
||||
if strings.Contains(err.Error(), " is not a recognized network service") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// set the dns server for the provided network interface
|
||||
// networksetup -setdnsservers Wi-Fi 8.8.8.8 1.1.1.1
|
||||
// TODO(cuonglm): use system API
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
|
||||
|
||||
// Note that networksetup won't modify search domains settings,
|
||||
// This assignment is just a placeholder to silent linter.
|
||||
_ = searchDomains
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name}
|
||||
args = append(args, nameservers...)
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
if err := resetDNS(iface); err != nil {
|
||||
// TODO: investiate whether we can detect this without relying on error message.
|
||||
if strings.Contains(err.Error(), " is not a recognized network service") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(cuonglm): use system API
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
|
||||
|
||||
cmd := "networksetup"
|
||||
args := []string{"-setdnsservers", iface.Name, "empty"}
|
||||
if out, err := exec.Command(cmd, args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%v: %w", string(out), err)
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if ns := ctrld.SavedStaticNameservers(iface); len(ns) > 0 {
|
||||
err = setDNS(iface, ns)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// currentDNS returns the current DNS servers for the specified interface
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return ctrld.CurrentNameserversFromResolvconf()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
cmd := "networksetup"
|
||||
args := []string{"-getdnsservers", iface.Name}
|
||||
out, err := exec.Command(cmd, args...).Output()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
var ns []string
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if ip := net.ParseIP(line); ip != nil {
|
||||
ns = append(ns, ip.String())
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
118
cmd/cli/os_freebsd.go
Normal file
118
cmd/cli/os_freebsd.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
// allocateIP allocates an IP address on the specified interface
|
||||
// sudo ifconfig lo0 127.0.0.53 alias
|
||||
func allocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", ip, "alias")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("allocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// deAllocateIP deallocates an IP address from the specified interface
|
||||
func deAllocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
|
||||
cmd := exec.Command("ifconfig", "lo0", ip, "-alias")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("deAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
// set the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
|
||||
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
ns := make([]netip.Addr, 0, len(nameservers))
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, netip.MustParseAddr(nameserver))
|
||||
}
|
||||
|
||||
osConfig := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list")
|
||||
}
|
||||
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to set DNS")
|
||||
return err
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
// resetDNS resets DNS servers for the specified interface
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
|
||||
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to create DNS OS configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.Close(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to rollback DNS setting")
|
||||
return err
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration reset successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// currentDNS returns the current DNS servers for the specified interface
|
||||
func currentDNS(_ *net.Interface) []string {
|
||||
return ctrld.CurrentNameserversFromResolvconf()
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
return currentDNS(iface), nil
|
||||
}
|
||||
318
cmd/cli/os_linux.go
Normal file
318
cmd/cli/os_linux.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/client6"
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
const resolvConfBackupFailedMsg = "open /etc/resolv.pre-ctrld-backup.conf: read-only file system"
|
||||
|
||||
type getDNS func(iface string) []string
|
||||
|
||||
// allocate loopback ip
|
||||
// sudo ip a add 127.0.0.2/24 dev lo
|
||||
func allocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Allocating IP address")
|
||||
cmd := exec.Command("ip", "a", "add", ip+"/24", "dev", "lo")
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msgf("AllocateIP failed: %s", string(out))
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address allocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func deAllocateIP(ip string) error {
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("Deallocating IP address")
|
||||
cmd := exec.Command("ip", "a", "del", ip+"/24", "dev", "lo")
|
||||
if err := cmd.Run(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("DeAllocateIP failed")
|
||||
return err
|
||||
}
|
||||
mainLog.Load().Debug().Str("ip", ip).Msg("IP address deallocated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
const maxSetDNSAttempts = 5
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Strs("nameservers", nameservers).Msg("Setting DNS configuration")
|
||||
|
||||
r, err := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to create dns os configurator")
|
||||
return err
|
||||
}
|
||||
|
||||
ns := make([]netip.Addr, 0, len(nameservers))
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, netip.MustParseAddr(nameserver))
|
||||
}
|
||||
|
||||
osConfig := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
osConfig.SearchDomains = sds
|
||||
} else {
|
||||
mainLog.Load().Debug().Err(err).Msg("Failed to get search domains list")
|
||||
}
|
||||
trySystemdResolve := false
|
||||
if err := r.SetDNS(osConfig); err != nil {
|
||||
if strings.Contains(err.Error(), "Rejected send message") &&
|
||||
strings.Contains(err.Error(), "org.freedesktop.network1.Manager") {
|
||||
mainLog.Load().Warn().Msg("Interfaces are managed by systemd-networkd, switch to systemd-resolve for setting DNS")
|
||||
trySystemdResolve = true
|
||||
goto systemdResolve
|
||||
}
|
||||
// This error happens on read-only file system, which causes ctrld failed to create backup
|
||||
// for /etc/resolv.conf file. It is ok, because the DNS is still set anyway, and restore
|
||||
// DNS will fallback to use DHCP if there's no backup /etc/resolv.conf file.
|
||||
// The error format is controlled by us, so checking for error string is fine.
|
||||
// See: ../../internal/dns/direct.go:L278
|
||||
if r.Mode() == "direct" && strings.Contains(err.Error(), resolvConfBackupFailedMsg) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
systemdResolve:
|
||||
if trySystemdResolve {
|
||||
// Stop systemd-networkd and retry setting DNS.
|
||||
if out, err := exec.Command("systemctl", "stop", "systemd-networkd").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%s: %w", string(out), err)
|
||||
}
|
||||
args := []string{"--interface=" + iface.Name, "--set-domain=~"}
|
||||
for _, nameserver := range nameservers {
|
||||
args = append(args, "--set-dns="+nameserver)
|
||||
}
|
||||
for i := 0; i < maxSetDNSAttempts; i++ {
|
||||
if out, err := exec.Command("systemd-resolve", args...).CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("%s: %w", string(out), err)
|
||||
}
|
||||
currentNS := currentDNS(iface)
|
||||
if isSubSet(nameservers, currentNS) {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
mainLog.Load().Debug().Msg("DNS was not set for some reason")
|
||||
}
|
||||
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("DNS configuration set successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
func resetDNS(iface *net.Interface) (err error) {
|
||||
mainLog.Load().Debug().Str("interface", iface.Name).Msg("Resetting DNS configuration")
|
||||
|
||||
defer func() {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
// Start systemd-networkd if present.
|
||||
if exe, _ := exec.LookPath("/lib/systemd/systemd-networkd"); exe != "" {
|
||||
_ = exec.Command("systemctl", "start", "systemd-networkd").Run()
|
||||
}
|
||||
if r, oerr := dns.NewOSConfigurator(logf, &health.Tracker{}, &controlknobs.Knobs{}, iface.Name); oerr == nil {
|
||||
_ = r.SetDNS(dns.OSConfig{})
|
||||
if err := r.Close(); err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to rollback dns setting")
|
||||
return
|
||||
}
|
||||
err = nil
|
||||
}
|
||||
}()
|
||||
|
||||
var ns []string
|
||||
c, err := nclient4.New(iface.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nclient4.New: %w", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
lease, err := c.Request(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("nclient4.Request: %w", err)
|
||||
}
|
||||
for _, nameserver := range lease.ACK.DNS() {
|
||||
if nameserver.Equal(net.IPv4zero) {
|
||||
continue
|
||||
}
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
|
||||
// TODO(cuonglm): handle DHCPv6 properly.
|
||||
mainLog.Load().Debug().Msg("Checking for ipv6 availability")
|
||||
if ctrldnet.IPv6Available(ctx) {
|
||||
c := client6.NewClient()
|
||||
conversation, err := c.Exchange(iface.Name)
|
||||
if err != nil && !errAddrInUse(err) {
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not exchange dhcpv6")
|
||||
}
|
||||
for _, packet := range conversation {
|
||||
if packet.Type() == dhcpv6.MessageTypeReply {
|
||||
msg, err := packet.GetInnerMessage()
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msg("Could not get inner dhcpv6 message")
|
||||
return nil
|
||||
}
|
||||
nameservers := msg.Options.DNS()
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msg("IPv6 is not available")
|
||||
}
|
||||
|
||||
return ignoringEINTR(func() error {
|
||||
return setDNS(iface, ns)
|
||||
})
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
resolvconfFunc := func(_ string) []string { return ctrld.CurrentNameserversFromResolvconf() }
|
||||
for _, fn := range []getDNS{getDNSByResolvectl, getDNSBySystemdResolved, getDNSByNmcli, resolvconfFunc} {
|
||||
if ns := fn(iface.Name); len(ns) > 0 {
|
||||
return ns
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// currentStaticDNS returns the current static DNS settings of given interface.
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
return currentDNS(iface), nil
|
||||
}
|
||||
|
||||
func getDNSByResolvectl(iface string) []string {
|
||||
b, err := exec.Command("resolvectl", "dns", "-i", iface).Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Fields(strings.SplitN(string(b), "%", 2)[0])
|
||||
if len(parts) > 2 {
|
||||
return parts[3:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDNSBySystemdResolved(iface string) []string {
|
||||
b, err := exec.Command("systemd-resolve", "--status", iface).Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return getDNSBySystemdResolvedFromReader(bytes.NewReader(b))
|
||||
}
|
||||
|
||||
func getDNSBySystemdResolvedFromReader(r io.Reader) []string {
|
||||
scanner := bufio.NewScanner(r)
|
||||
var ret []string
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if len(ret) > 0 {
|
||||
if net.ParseIP(line) != nil {
|
||||
ret = append(ret, line)
|
||||
}
|
||||
continue
|
||||
}
|
||||
after, found := strings.CutPrefix(line, "DNS Servers: ")
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
if net.ParseIP(after) != nil {
|
||||
ret = append(ret, after)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func getDNSByNmcli(iface string) []string {
|
||||
b, err := exec.Command("nmcli", "dev", "show", iface).Output()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
s := bufio.NewScanner(bytes.NewReader(b))
|
||||
var dns []string
|
||||
do := func(line string) {
|
||||
parts := strings.SplitN(line, ":", 2)
|
||||
if len(parts) > 1 {
|
||||
dns = append(dns, strings.TrimSpace(parts[1]))
|
||||
}
|
||||
}
|
||||
for s.Scan() {
|
||||
line := s.Text()
|
||||
switch {
|
||||
case strings.HasPrefix(line, "IP4.DNS"):
|
||||
fallthrough
|
||||
case strings.HasPrefix(line, "IP6.DNS"):
|
||||
do(line)
|
||||
}
|
||||
}
|
||||
return dns
|
||||
}
|
||||
|
||||
func ignoringEINTR(fn func() error) error {
|
||||
for {
|
||||
err := fn()
|
||||
if err != syscall.EINTR {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isSubSet reports whether s2 contains all elements of s1.
|
||||
func isSubSet(s1, s2 []string) bool {
|
||||
ok := true
|
||||
for _, ns := range s1 {
|
||||
if slices.Contains(s2, ns) {
|
||||
continue
|
||||
}
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
return ok
|
||||
}
|
||||
23
cmd/cli/os_linux_test.go
Normal file
23
cmd/cli/os_linux_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_getDNSBySystemdResolvedFromReader(t *testing.T) {
|
||||
r := strings.NewReader(`Link 2 (eth0)
|
||||
Current Scopes: DNS
|
||||
LLMNR setting: yes
|
||||
MulticastDNS setting: no
|
||||
DNSSEC setting: no
|
||||
DNSSEC supported: no
|
||||
DNS Servers: 8.8.8.8
|
||||
8.8.4.4`)
|
||||
want := []string{"8.8.8.8", "8.8.4.4"}
|
||||
ns := getDNSBySystemdResolvedFromReader(r)
|
||||
if !reflect.DeepEqual(ns, want) {
|
||||
t.Logf("unexpected result, want: %v, got: %v", want, ns)
|
||||
}
|
||||
}
|
||||
13
cmd/cli/os_others.go
Normal file
13
cmd/cli/os_others.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !linux && !darwin && !freebsd
|
||||
|
||||
package cli
|
||||
|
||||
// allocateIP allocates an IP address on the specified interface
|
||||
func allocateIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// deAllocateIP deallocates an IP address from the specified interface
|
||||
func deAllocateIP(ip string) error {
|
||||
return nil
|
||||
}
|
||||
220
cmd/cli/os_windows.go
Normal file
220
cmd/cli/os_windows.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
)
|
||||
|
||||
const (
|
||||
v4InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\`
|
||||
v6InterfaceKeyPathFormat = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\`
|
||||
)
|
||||
|
||||
// setDnsIgnoreUnusableInterface likes setDNS, but return a nil error if the interface is not usable.
|
||||
func setDnsIgnoreUnusableInterface(iface *net.Interface, nameservers []string) error {
|
||||
return setDNS(iface, nameservers)
|
||||
}
|
||||
|
||||
// setDNS sets the dns server for the provided network interface
|
||||
func setDNS(iface *net.Interface, nameservers []string) error {
|
||||
if len(nameservers) == 0 {
|
||||
return errors.New("empty DNS nameservers")
|
||||
}
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("setDNS: %w", err)
|
||||
}
|
||||
var (
|
||||
serversV4 []netip.Addr
|
||||
serversV6 []netip.Addr
|
||||
)
|
||||
for _, ns := range nameservers {
|
||||
if addr, err := netip.ParseAddr(ns); err == nil {
|
||||
if addr.Is4() {
|
||||
serversV4 = append(serversV4, addr)
|
||||
} else {
|
||||
serversV6 = append(serversV6, addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note that Windows won't modify the current search domains if passing nil to luid.SetDNS function.
|
||||
// searchDomains is still implemented for Windows just in case Windows API changes in future versions.
|
||||
_ = searchDomains
|
||||
|
||||
if len(serversV4) == 0 && len(serversV6) == 0 {
|
||||
return errors.New("invalid DNS nameservers")
|
||||
}
|
||||
if len(serversV4) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET, serversV4, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv4: %w", err)
|
||||
}
|
||||
}
|
||||
if len(serversV6) > 0 {
|
||||
if err := luid.SetDNS(windows.AF_INET6, serversV6, nil); err != nil {
|
||||
return fmt.Errorf("could not set DNS ipv6: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetDnsIgnoreUnusableInterface likes resetDNS, but return a nil error if the interface is not usable.
|
||||
func resetDnsIgnoreUnusableInterface(iface *net.Interface) error {
|
||||
return resetDNS(iface)
|
||||
}
|
||||
|
||||
// resetDNS resets DNS servers for the specified interface
|
||||
func resetDNS(iface *net.Interface) error {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("resetDNS: %w", err)
|
||||
}
|
||||
// Restoring DHCP settings.
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv4: %w", err)
|
||||
}
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("could not reset DNS ipv6: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreDNS restores the DNS settings of the given interface.
|
||||
// this should only be executed upon turning off the ctrld service.
|
||||
func restoreDNS(iface *net.Interface) (err error) {
|
||||
if nss := ctrld.SavedStaticNameservers(iface); len(nss) > 0 {
|
||||
v4ns := make([]string, 0, 2)
|
||||
v6ns := make([]string, 0, 2)
|
||||
for _, ns := range nss {
|
||||
if ctrldnet.IsIPv6(ns) {
|
||||
v6ns = append(v6ns, ns)
|
||||
} else {
|
||||
v4ns = append(v4ns, ns)
|
||||
}
|
||||
}
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("restoreDNS: %w", err)
|
||||
}
|
||||
|
||||
if len(v4ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv4 static DNS for interface %q: %v", iface.Name, v4ns)
|
||||
if err := setDNS(iface, v4ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv4 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv4 clear): %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(v6ns) > 0 {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv6 static DNS for interface %q: %v", iface.Name, v6ns)
|
||||
if err := setDNS(iface, v6ns); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6): %w", err)
|
||||
}
|
||||
} else {
|
||||
mainLog.Load().Debug().Msgf("Restoring IPv6 DHCP for interface %q", iface.Name)
|
||||
if err := luid.SetDNS(windows.AF_INET6, nil, nil); err != nil {
|
||||
return fmt.Errorf("restoreDNS (IPv6 clear): %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// currentDNS returns the current DNS servers for the specified interface
|
||||
func currentDNS(iface *net.Interface) []string {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to get interface LUID")
|
||||
return nil
|
||||
}
|
||||
nameservers, err := luid.DNS()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to get interface DNS")
|
||||
return nil
|
||||
}
|
||||
ns := make([]string, 0, len(nameservers))
|
||||
for _, nameserver := range nameservers {
|
||||
ns = append(ns, nameserver.String())
|
||||
}
|
||||
return ns
|
||||
}
|
||||
|
||||
// currentStaticDNS checks both the IPv4 and IPv6 paths for static DNS values using keys
|
||||
// like "NameServer" and "ProfileNameServer".
|
||||
func currentStaticDNS(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fallback winipcfg.LUIDFromIndex: %w", err)
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fallback luid.GUID: %w", err)
|
||||
}
|
||||
|
||||
var ns []string
|
||||
keyPaths := []string{v4InterfaceKeyPathFormat, v6InterfaceKeyPathFormat}
|
||||
for _, path := range keyPaths {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, interfaceKeyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Failed to open registry key %q for interface %q; trying next key", interfaceKeyPath, iface.Name)
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer k.Close()
|
||||
for _, keyName := range []string{"NameServer", "ProfileNameServer"} {
|
||||
value, _, err := k.GetStringValue(keyName)
|
||||
if err != nil && !errors.Is(err, registry.ErrNotExist) {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Error reading %s registry key", keyName)
|
||||
continue
|
||||
}
|
||||
if len(value) > 0 {
|
||||
mainLog.Load().Debug().Msgf("Found static DNS for interface %q: %s", iface.Name, value)
|
||||
parsed := parseDNSServers(value)
|
||||
for _, pns := range parsed {
|
||||
if !slices.Contains(ns, pns) {
|
||||
ns = append(ns, pns)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
if len(ns) == 0 {
|
||||
mainLog.Load().Debug().Msgf("No static DNS values found for interface %q", iface.Name)
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// parseDNSServers splits a DNS server string that may be comma- or space-separated,
|
||||
// and trims any extraneous whitespace or null characters.
|
||||
func parseDNSServers(val string) []string {
|
||||
fields := strings.FieldsFunc(val, func(r rune) bool {
|
||||
return r == ' ' || r == ','
|
||||
})
|
||||
var servers []string
|
||||
for _, f := range fields {
|
||||
trimmed := strings.TrimSpace(f)
|
||||
if len(trimmed) > 0 {
|
||||
servers = append(servers, trimmed)
|
||||
}
|
||||
}
|
||||
return servers
|
||||
}
|
||||
76
cmd/cli/os_windows_test.go
Normal file
76
cmd/cli/os_windows_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func Test_currentStaticDNS(t *testing.T) {
|
||||
iface, err := net.InterfaceByName(defaultIfaceName())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
start := time.Now()
|
||||
staticDns, err := currentStaticDNS(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Windows API takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
start = time.Now()
|
||||
staticDnsPowershell, err := currentStaticDnsPowershell(iface)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Using Powershell takes: %d", time.Since(start).Milliseconds())
|
||||
|
||||
slices.Sort(staticDns)
|
||||
slices.Sort(staticDnsPowershell)
|
||||
if !slices.Equal(staticDns, staticDnsPowershell) {
|
||||
t.Fatalf("result mismatch, want: %v, got: %v", staticDnsPowershell, staticDns)
|
||||
}
|
||||
}
|
||||
|
||||
func currentStaticDnsPowershell(iface *net.Interface) ([]string, error) {
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ns []string
|
||||
for _, path := range []string{"HKLM:\\" + v4InterfaceKeyPathFormat, "HKLM:\\" + v6InterfaceKeyPathFormat} {
|
||||
interfaceKeyPath := path + guid.String()
|
||||
found := false
|
||||
for _, key := range []string{"NameServer", "ProfileNameServer"} {
|
||||
if found {
|
||||
continue
|
||||
}
|
||||
cmd := fmt.Sprintf(`Get-ItemPropertyValue -Path "%s" -Name "%s"`, interfaceKeyPath, key)
|
||||
out, err := powershell(cmd)
|
||||
if err == nil && len(out) > 0 {
|
||||
found = true
|
||||
for _, e := range strings.Split(string(out), ",") {
|
||||
ns = append(ns, strings.TrimRight(e, "\x00"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ns, nil
|
||||
}
|
||||
|
||||
// powershell runs the given powershell command.
|
||||
func powershell(cmd string) ([]byte, error) {
|
||||
out, err := exec.Command("powershell", "-Command", cmd).CombinedOutput()
|
||||
return bytes.TrimSpace(out), err
|
||||
}
|
||||
1491
cmd/cli/prog.go
Normal file
1491
cmd/cli/prog.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,23 +1,13 @@
|
||||
package main
|
||||
package cli
|
||||
|
||||
import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func (p *prog) preRun() {
|
||||
if !service.Interactive() {
|
||||
p.setDNS()
|
||||
}
|
||||
}
|
||||
|
||||
// setDependencies sets service dependencies for Darwin
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
func (p *prog) preStop() {
|
||||
if !service.Interactive() {
|
||||
p.resetDNS()
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -6,17 +6,11 @@ import (
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func (p *prog) preRun() {
|
||||
if !service.Interactive() {
|
||||
p.setDNS()
|
||||
}
|
||||
}
|
||||
|
||||
// setDependencies sets service dependencies for FreeBSD
|
||||
func setDependencies(svc *service.Config) {
|
||||
// TODO(cuonglm): remove once https://github.com/kardianos/service/issues/359 fixed.
|
||||
_ = os.MkdirAll("/usr/local/etc/rc.d", 0755)
|
||||
}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {}
|
||||
|
||||
func (p *prog) preStop() {}
|
||||
62
cmd/cli/prog_linux.go
Normal file
62
cmd/cli/prog_linux.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if r, err := newLoopbackOSConfigurator(); err == nil {
|
||||
useSystemdResolved = r.Mode() == "systemd-resolved"
|
||||
}
|
||||
// Disable quic-go's ECN support by default, see https://github.com/quic-go/quic-go/issues/3911
|
||||
if os.Getenv("QUIC_GO_DISABLE_ECN") == "" {
|
||||
os.Setenv("QUIC_GO_DISABLE_ECN", "true")
|
||||
}
|
||||
}
|
||||
|
||||
// setDependencies sets service dependencies for Linux
|
||||
func setDependencies(svc *service.Config) {
|
||||
svc.Dependencies = []string{
|
||||
"Wants=network-online.target",
|
||||
"After=network-online.target",
|
||||
"Wants=NetworkManager-wait-online.service",
|
||||
"After=NetworkManager-wait-online.service",
|
||||
"Wants=nss-lookup.target",
|
||||
"After=nss-lookup.target",
|
||||
}
|
||||
if out, _ := exec.Command("networkctl", "--no-pager").CombinedOutput(); len(out) > 0 {
|
||||
if wantsSystemDNetworkdWaitOnline(bytes.NewReader(out)) {
|
||||
svc.Dependencies = append(svc.Dependencies, "Wants=systemd-networkd-wait-online.service")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
// wantsSystemDNetworkdWaitOnline reports whether "systemd-networkd-wait-online" service
|
||||
// is required to be added to ctrld dependencies services.
|
||||
// The input reader r is the output of "networkctl --no-pager" command.
|
||||
func wantsSystemDNetworkdWaitOnline(r io.Reader) bool {
|
||||
scanner := bufio.NewScanner(r)
|
||||
// Skip header
|
||||
scanner.Scan()
|
||||
configured := false
|
||||
for scanner.Scan() {
|
||||
fields := strings.Fields(scanner.Text())
|
||||
if len(fields) > 0 && fields[len(fields)-1] == "configured" {
|
||||
configured = true
|
||||
break
|
||||
}
|
||||
}
|
||||
return configured
|
||||
}
|
||||
48
cmd/cli/prog_linux_test.go
Normal file
48
cmd/cli/prog_linux_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
networkctlUnmanagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||
1 lo loopback carrier unmanaged
|
||||
2 wlp0s20f3 wlan routable unmanaged
|
||||
3 tailscale0 none routable unmanaged
|
||||
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||
5 docker0 bridge no-carrier unmanaged
|
||||
|
||||
5 links listed.
|
||||
`
|
||||
networkctlManagedOutput = `IDX LINK TYPE OPERATIONAL SETUP
|
||||
1 lo loopback carrier unmanaged
|
||||
2 wlp0s20f3 wlan routable configured
|
||||
3 tailscale0 none routable unmanaged
|
||||
4 br-9ac33145e060 bridge no-carrier unmanaged
|
||||
5 docker0 bridge no-carrier unmanaged
|
||||
|
||||
5 links listed.
|
||||
`
|
||||
)
|
||||
|
||||
func Test_wantsSystemDNetworkdWaitOnline(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
r io.Reader
|
||||
required bool
|
||||
}{
|
||||
{"unmanaged", strings.NewReader(networkctlUnmanagedOutput), false},
|
||||
{"managed", strings.NewReader(networkctlManagedOutput), true},
|
||||
{"empty", strings.NewReader(""), false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if required := wantsSystemDNetworkdWaitOnline(tc.r); required != tc.required {
|
||||
t.Errorf("wants %v got %v", tc.required, required)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
33
cmd/cli/prog_log.go
Normal file
33
cmd/cli/prog_log.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package cli
|
||||
|
||||
import "github.com/Control-D-Inc/ctrld"
|
||||
|
||||
// Debug starts a new message with debug level.
|
||||
func (p *prog) Debug() *ctrld.LogEvent {
|
||||
return p.logger.Load().Debug()
|
||||
}
|
||||
|
||||
// Warn starts a new message with warn level.
|
||||
func (p *prog) Warn() *ctrld.LogEvent {
|
||||
return p.logger.Load().Warn()
|
||||
}
|
||||
|
||||
// Info starts a new message with info level.
|
||||
func (p *prog) Info() *ctrld.LogEvent {
|
||||
return p.logger.Load().Info()
|
||||
}
|
||||
|
||||
// Fatal starts a new message with fatal level.
|
||||
func (p *prog) Fatal() *ctrld.LogEvent {
|
||||
return p.logger.Load().Fatal()
|
||||
}
|
||||
|
||||
// Error starts a new message with error level.
|
||||
func (p *prog) Error() *ctrld.LogEvent {
|
||||
return p.logger.Load().Error()
|
||||
}
|
||||
|
||||
// Notice starts a new message with notice level.
|
||||
func (p *prog) Notice() *ctrld.LogEvent {
|
||||
return p.logger.Load().Notice()
|
||||
}
|
||||
14
cmd/cli/prog_others.go
Normal file
14
cmd/cli/prog_others.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build !linux && !freebsd && !darwin && !windows
|
||||
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
// setDependencies sets service dependencies for other platforms
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
266
cmd/cli/prog_test.go
Normal file
266
cmd/cli/prog_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
func Test_prog_dnsWatchdogEnabled(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is true.
|
||||
assert.True(t, p.dnsWatchdogEnabled())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
}{
|
||||
{"enabled", true},
|
||||
{"disabled", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogEnabled = &tc.enabled
|
||||
assert.Equal(t, tc.enabled, p.dnsWatchdogEnabled())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_dnsWatchdogInterval(t *testing.T) {
|
||||
p := &prog{cfg: &ctrld.Config{}}
|
||||
|
||||
// Default value is 20s.
|
||||
assert.Equal(t, dnsWatchdogDefaultInterval, p.dnsWatchdogDuration())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
duration time.Duration
|
||||
expected time.Duration
|
||||
}{
|
||||
{"valid", time.Minute, time.Minute},
|
||||
{"zero", 0, dnsWatchdogDefaultInterval},
|
||||
{"nagative", time.Duration(-1 * time.Minute), dnsWatchdogDefaultInterval},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p.cfg.Service.DnsWatchdogInvterval = &tc.duration
|
||||
assert.Equal(t, tc.expected, p.dnsWatchdogDuration())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_shouldUpgrade(t *testing.T) {
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is empty",
|
||||
},
|
||||
{
|
||||
name: "invalid version target",
|
||||
versionTarget: "invalid-version",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when version target is invalid",
|
||||
},
|
||||
{
|
||||
name: "same version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version equals current version",
|
||||
},
|
||||
{
|
||||
name: "older version",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v1.1.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should skip upgrade when target version is older than current version",
|
||||
},
|
||||
{
|
||||
name: "patch upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow patch version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "minor upgrade allowed",
|
||||
versionTarget: "v1.1.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow minor version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "major upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "major downgrade blocked",
|
||||
versionTarget: "v1.0.0",
|
||||
currentVersion: makeVersion("v2.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block major version downgrade",
|
||||
},
|
||||
{
|
||||
name: "version without v prefix",
|
||||
versionTarget: "1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should handle version target without v prefix",
|
||||
},
|
||||
{
|
||||
name: "complex version upgrade allowed",
|
||||
versionTarget: "v1.5.3",
|
||||
currentVersion: makeVersion("v1.4.2"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow complex version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "complex major upgrade blocked",
|
||||
versionTarget: "v3.1.0",
|
||||
currentVersion: makeVersion("v2.5.3"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block complex major version upgrade",
|
||||
},
|
||||
{
|
||||
name: "pre-release version upgrade allowed",
|
||||
versionTarget: "v1.0.1-beta.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow pre-release version upgrade within same major version",
|
||||
},
|
||||
{
|
||||
name: "pre-release major upgrade blocked",
|
||||
versionTarget: "v2.0.0-alpha.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block pre-release major version upgrade",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
|
||||
|
||||
// Call the function and capture the result
|
||||
result := shouldUpgrade(tc.versionTarget, tc.currentVersion, testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_selfUpgradeCheck(t *testing.T) {
|
||||
// Helper function to create a version
|
||||
makeVersion := func(v string) *semver.Version {
|
||||
ver, err := semver.NewVersion(v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create version %s: %v", v, err)
|
||||
}
|
||||
return ver
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
currentVersion *semver.Version
|
||||
shouldUpgrade bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "upgrade allowed",
|
||||
versionTarget: "v1.0.1",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: true,
|
||||
description: "should allow upgrade and attempt to perform it",
|
||||
},
|
||||
{
|
||||
name: "upgrade blocked",
|
||||
versionTarget: "v2.0.0",
|
||||
currentVersion: makeVersion("v1.0.0"),
|
||||
shouldUpgrade: false,
|
||||
description: "should block upgrade and not attempt to perform it",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
|
||||
|
||||
// Call the function and capture the result
|
||||
result := selfUpgradeCheck(tc.versionTarget, tc.currentVersion, testLogger)
|
||||
|
||||
// Assert the expected result
|
||||
assert.Equal(t, tc.shouldUpgrade, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_performUpgrade(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
versionTarget string
|
||||
expectedResult bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid version target",
|
||||
versionTarget: "v1.0.1",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade with valid version target",
|
||||
},
|
||||
{
|
||||
name: "empty version target",
|
||||
versionTarget: "",
|
||||
expectedResult: true,
|
||||
description: "should attempt to perform upgrade even with empty version target",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test logger
|
||||
testLogger := &ctrld.Logger{Logger: zap.NewNop()}
|
||||
// Call the function and capture the result
|
||||
result := performUpgrade(tc.versionTarget, testLogger)
|
||||
assert.Equal(t, tc.expectedResult, result, tc.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,12 @@
|
||||
//go:build !linux && !freebsd && !darwin
|
||||
|
||||
package main
|
||||
package cli
|
||||
|
||||
import "github.com/kardianos/service"
|
||||
|
||||
func (p *prog) preRun() {}
|
||||
|
||||
// setDependencies sets service dependencies for Windows
|
||||
func setDependencies(svc *service.Config) {}
|
||||
|
||||
// setWorkingDirectory sets the working directory for the service
|
||||
func setWorkingDirectory(svc *service.Config, dir string) {
|
||||
// WorkingDirectory is not supported on Windows.
|
||||
svc.WorkingDirectory = dir
|
||||
}
|
||||
|
||||
func (p *prog) preStop() {}
|
||||
66
cmd/cli/prometheus.go
Normal file
66
cmd/cli/prometheus.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package cli
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
|
||||
// Prometheus metrics label constants for consistent labeling across all metrics
|
||||
// These ensure standardized metric labeling for monitoring and alerting
|
||||
const (
|
||||
metricsLabelListener = "listener"
|
||||
metricsLabelClientSourceIP = "client_source_ip"
|
||||
metricsLabelClientMac = "client_mac"
|
||||
metricsLabelClientHostname = "client_hostname"
|
||||
metricsLabelUpstream = "upstream"
|
||||
metricsLabelRRType = "rr_type"
|
||||
metricsLabelRCode = "rcode"
|
||||
)
|
||||
|
||||
// statsVersion represent ctrld version.
|
||||
// This metric provides version information for monitoring and debugging
|
||||
var statsVersion = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_build_info",
|
||||
Help: "Version of ctrld process.",
|
||||
}, []string{"gitref", "goversion", "version"})
|
||||
|
||||
// statsTimeStart represents start time of ctrld service.
|
||||
// This metric tracks service uptime and helps with monitoring service restarts
|
||||
var statsTimeStart = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ctrld_time_seconds",
|
||||
Help: "Start time of the ctrld process since unix epoch in seconds.",
|
||||
})
|
||||
|
||||
// statsQueriesCountLabels defines the labels for query count metrics
|
||||
// These labels provide detailed breakdown of DNS query statistics
|
||||
var statsQueriesCountLabels = []string{
|
||||
metricsLabelListener,
|
||||
metricsLabelClientSourceIP,
|
||||
metricsLabelClientMac,
|
||||
metricsLabelClientHostname,
|
||||
metricsLabelUpstream,
|
||||
metricsLabelRRType,
|
||||
metricsLabelRCode,
|
||||
}
|
||||
|
||||
// statsQueriesCount counts total number of queries.
|
||||
// This provides comprehensive DNS query statistics for monitoring and alerting
|
||||
var statsQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_queries_count",
|
||||
Help: "Total number of queries.",
|
||||
}, statsQueriesCountLabels)
|
||||
|
||||
// statsClientQueriesCount counts total number of queries of a client.
|
||||
//
|
||||
// The labels "client_source_ip", "client_mac", "client_hostname" are unbounded,
|
||||
// thus this stat is highly inefficient if there are many devices.
|
||||
// This metric should be used carefully in high-client environments
|
||||
var statsClientQueriesCount = prometheus.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ctrld_client_queries_count",
|
||||
Help: "Total number queries of a client.",
|
||||
}, []string{metricsLabelClientSourceIP, metricsLabelClientMac, metricsLabelClientHostname})
|
||||
|
||||
// WithLabelValuesInc increases prometheus counter by 1 if query stats is enabled.
|
||||
// This provides conditional metric collection to avoid performance impact when metrics are disabled
|
||||
func (p *prog) WithLabelValuesInc(c *prometheus.CounterVec, lvs ...string) {
|
||||
if p.metricsQueryStats.Load() {
|
||||
c.WithLabelValues(lvs...).Inc()
|
||||
}
|
||||
}
|
||||
19
cmd/cli/reload_others.go
Normal file
19
cmd/cli/reload_others.go
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// notifyReloadSigCh sends reload signal to the channel
|
||||
func notifyReloadSigCh(ch chan os.Signal) {
|
||||
signal.Notify(ch, syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
// sendReloadSignal sends a reload signal to the current process
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
return syscall.Kill(syscall.Getpid(), syscall.SIGUSR1)
|
||||
}
|
||||
20
cmd/cli/reload_windows.go
Normal file
20
cmd/cli/reload_windows.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// notifyReloadSigCh is a no-op on Windows platforms
|
||||
func notifyReloadSigCh(ch chan os.Signal) {}
|
||||
|
||||
// sendReloadSignal sends a reload signal to the program
|
||||
func (p *prog) sendReloadSignal() error {
|
||||
select {
|
||||
case p.reloadCh <- struct{}{}:
|
||||
return nil
|
||||
case <-time.After(5 * time.Second):
|
||||
}
|
||||
return errors.New("timeout while sending reload signal")
|
||||
}
|
||||
153
cmd/cli/resolvconf.go
Normal file
153
cmd/cli/resolvconf.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// parseResolvConfNameservers reads the resolv.conf file and returns the nameservers found.
|
||||
// Returns nil if no nameservers are found.
|
||||
// This function parses the system DNS configuration to understand current nameserver settings
|
||||
func (p *prog) parseResolvConfNameservers(path string) ([]string, error) {
|
||||
return resolvconffile.NameserversFromFile(path)
|
||||
}
|
||||
|
||||
// watchResolvConf watches any changes to /etc/resolv.conf file,
|
||||
// and reverting to the original config set by ctrld.
|
||||
// This ensures that DNS settings are not overridden by other applications or system processes
|
||||
func (p *prog) watchResolvConf(iface *net.Interface, ns []netip.Addr, setDnsFn func(iface *net.Interface, ns []netip.Addr) error) {
|
||||
resolvConfPath := "/etc/resolv.conf"
|
||||
// Evaluating symbolics link to watch the target file that /etc/resolv.conf point to.
|
||||
// This handles systems where resolv.conf is a symlink to another location
|
||||
if rp, _ := filepath.EvalSymlinks(resolvConfPath); rp != "" {
|
||||
resolvConfPath = rp
|
||||
}
|
||||
p.Debug().Msgf("Start watching %s file", resolvConfPath)
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
p.Warn().Err(err).Msg("Could not create watcher for /etc/resolv.conf")
|
||||
return
|
||||
}
|
||||
defer watcher.Close()
|
||||
|
||||
// We watch /etc instead of /etc/resolv.conf directly,
|
||||
// see: https://github.com/fsnotify/fsnotify#watching-a-file-doesnt-work-well
|
||||
// This is necessary because some systems don't properly notify on file changes
|
||||
watchDir := filepath.Dir(resolvConfPath)
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
p.Warn().Err(err).Msgf("Could not add %s to watcher list", watchDir)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-p.stopCh:
|
||||
p.Debug().Msgf("Stopping watcher for %s", resolvConfPath)
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if p.recoveryRunning.Load() {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Name != resolvConfPath { // skip if not /etc/resolv.conf changes.
|
||||
continue
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
p.Debug().Msgf("/etc/resolv.conf changes detected, reading changes...")
|
||||
|
||||
// Convert expected nameservers to strings for comparison
|
||||
// This allows us to detect when the resolv.conf has been modified
|
||||
expectedNS := make([]string, len(ns))
|
||||
for i, addr := range ns {
|
||||
expectedNS[i] = addr.String()
|
||||
}
|
||||
|
||||
var foundNS []string
|
||||
var err error
|
||||
|
||||
maxRetries := 1
|
||||
for retry := 0; retry < maxRetries; retry++ {
|
||||
foundNS, err = p.parseResolvConfNameservers(resolvConfPath)
|
||||
if err != nil {
|
||||
p.Error().Err(err).Msg("Failed to read resolv.conf content")
|
||||
break
|
||||
}
|
||||
|
||||
// If we found nameservers, break out of retry loop
|
||||
// This handles cases where the file is being written but not yet complete
|
||||
if len(foundNS) > 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Only retry if we found no nameservers
|
||||
// This handles temporary file states during updates
|
||||
if retry < maxRetries-1 {
|
||||
p.Debug().Msgf("resolv.conf has no nameserver entries, retry %d/%d in 2 seconds", retry+1, maxRetries)
|
||||
select {
|
||||
case <-p.stopCh:
|
||||
return
|
||||
case <-p.dnsWatcherStopCh:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
p.Debug().Msg("resolv.conf remained empty after all retries")
|
||||
}
|
||||
}
|
||||
|
||||
// If we found nameservers, check if they match what we expect
|
||||
if len(foundNS) > 0 {
|
||||
// Check if the nameservers match exactly what we expect
|
||||
matches := len(foundNS) == len(expectedNS)
|
||||
if matches {
|
||||
for i := range foundNS {
|
||||
if foundNS[i] != expectedNS[i] {
|
||||
matches = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.Debug().
|
||||
Strs("found", foundNS).
|
||||
Strs("expected", expectedNS).
|
||||
Bool("matches", matches).
|
||||
Msg("checking nameservers")
|
||||
|
||||
// Only revert if the nameservers don't match
|
||||
if !matches {
|
||||
if err := watcher.Remove(watchDir); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to pause watcher")
|
||||
continue
|
||||
}
|
||||
|
||||
if err := setDnsFn(iface, ns); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to revert /etc/resolv.conf changes")
|
||||
}
|
||||
|
||||
if err := watcher.Add(watchDir); err != nil {
|
||||
p.Error().Err(err).Msg("Failed to continue running watcher")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p.Error().Err(err).Msg("Could not get event for /etc/resolv.conf")
|
||||
}
|
||||
}
|
||||
}
|
||||
49
cmd/cli/resolvconf_darwin.go
Normal file
49
cmd/cli/resolvconf_darwin.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
servers := make([]string, len(ns))
|
||||
for i := range ns {
|
||||
servers[i] = ns[i].String()
|
||||
}
|
||||
if err := setDNS(iface, servers); err != nil {
|
||||
return err
|
||||
}
|
||||
slices.Sort(servers)
|
||||
curNs := currentDNS(iface)
|
||||
slices.Sort(curNs)
|
||||
if !slices.Equal(curNs, servers) {
|
||||
c, err := resolvconffile.ParseFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Nameservers = ns
|
||||
f, err := os.Create(resolvConfPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := c.Write(f); err != nil {
|
||||
return err
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
return true
|
||||
}
|
||||
52
cmd/cli/resolvconf_not_darwin_unix.go
Normal file
52
cmd/cli/resolvconf_not_darwin_unix.go
Normal file
@@ -0,0 +1,52 @@
|
||||
//go:build unix && !darwin
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"tailscale.com/control/controlknobs"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of the resolv.conf file using the given nameservers list.
|
||||
func (p *prog) setResolvConf(iface *net.Interface, ns []netip.Addr) error {
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oc := dns.OSConfig{
|
||||
Nameservers: ns,
|
||||
SearchDomains: []dnsname.FQDN{},
|
||||
}
|
||||
if sds, err := searchDomains(); err == nil {
|
||||
oc.SearchDomains = sds
|
||||
} else {
|
||||
p.Debug().Err(err).Msg("Failed to get search domains list when reverting resolv.conf file")
|
||||
}
|
||||
return r.SetDNS(oc)
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
r, err := newLoopbackOSConfigurator()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
switch r.Mode() {
|
||||
case "direct", "resolvconf":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// newLoopbackOSConfigurator creates an OSConfigurator for DNS management using the "lo" interface.
|
||||
func newLoopbackOSConfigurator() (dns.OSConfigurator, error) {
|
||||
return dns.NewOSConfigurator(noopLogf, &health.Tracker{}, &controlknobs.Knobs{}, "lo")
|
||||
}
|
||||
51
cmd/cli/resolvconf_test.go
Normal file
51
cmd/cli/resolvconf_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/dns/resolvconffile"
|
||||
)
|
||||
|
||||
func oldParseResolvConfNameservers(path string) ([]string, error) {
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the file for "nameserver" lines
|
||||
var currentNS []string
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "nameserver") {
|
||||
parts := strings.Fields(trimmed)
|
||||
if len(parts) >= 2 {
|
||||
currentNS = append(currentNS, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentNS, nil
|
||||
}
|
||||
|
||||
// Test_prog_parseResolvConfNameservers tests the parsing of nameservers from resolv.conf content.
|
||||
// Note: The previous implementation was removed to reduce code duplication and consolidate
|
||||
// the resolv.conf handling logic into a single unified approach. All resolv.conf parsing
|
||||
// is now handled by the resolvconffile package, which provides a consistent interface
|
||||
// for both reading and modifying resolv.conf files across different platforms.
|
||||
func Test_prog_parseResolvConfNameservers(t *testing.T) {
|
||||
oldNss, _ := oldParseResolvConfNameservers(resolvconffile.Path)
|
||||
p := &prog{}
|
||||
nss, _ := p.parseResolvConfNameservers(resolvconffile.Path)
|
||||
slices.Sort(oldNss)
|
||||
slices.Sort(nss)
|
||||
if !slices.Equal(oldNss, nss) {
|
||||
t.Errorf("result mismatched, old: %v, new: %v", oldNss, nss)
|
||||
}
|
||||
t.Logf("result: %v", nss)
|
||||
}
|
||||
16
cmd/cli/resolvconf_windows.go
Normal file
16
cmd/cli/resolvconf_windows.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// setResolvConf sets the content of resolv.conf file using the given nameservers list.
|
||||
func (p *prog) setResolvConf(_ *net.Interface, _ []netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// shouldWatchResolvconf reports whether ctrld should watch changes to resolv.conf file with given OS configurator.
|
||||
func shouldWatchResolvconf() bool {
|
||||
return false
|
||||
}
|
||||
14
cmd/cli/search_domains_unix.go
Normal file
14
cmd/cli/search_domains_unix.go
Normal file
@@ -0,0 +1,14 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"tailscale.com/util/dnsname"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/resolvconffile"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
return resolvconffile.SearchDomains()
|
||||
}
|
||||
43
cmd/cli/search_domains_windows.go
Normal file
43
cmd/cli/search_domains_windows.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
"tailscale.com/util/dnsname"
|
||||
)
|
||||
|
||||
// searchDomains returns the current search domains config.
|
||||
func searchDomains() ([]dnsname.FQDN, error) {
|
||||
flags := winipcfg.GAAFlagIncludeGateways |
|
||||
winipcfg.GAAFlagIncludePrefix
|
||||
|
||||
aas, err := winipcfg.GetAdaptersAddresses(syscall.AF_UNSPEC, flags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("winipcfg.GetAdaptersAddresses: %w", err)
|
||||
}
|
||||
|
||||
var sds []dnsname.FQDN
|
||||
for _, aa := range aas {
|
||||
if aa.OperStatus != winipcfg.IfOperStatusUp {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if software loopback or other non-physical types
|
||||
// This is to avoid the "Loopback Pseudo-Interface 1" issue we see on windows
|
||||
if aa.IfType == winipcfg.IfTypeSoftwareLoopback {
|
||||
continue
|
||||
}
|
||||
|
||||
for a := aa.FirstDNSSuffix; a != nil; a = a.Next {
|
||||
d, err := dnsname.ToFQDN(a.String())
|
||||
if err != nil {
|
||||
mainLog.Load().Debug().Err(err).Msgf("Failed to parse domain: %s", a.String())
|
||||
continue
|
||||
}
|
||||
sds = append(sds, d)
|
||||
}
|
||||
}
|
||||
return sds, nil
|
||||
}
|
||||
8
cmd/cli/self_delete_others.go
Normal file
8
cmd/cli/self_delete_others.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
var supportedSelfDelete = true
|
||||
|
||||
// selfDeleteExe performs self-deletion on non-Windows platforms
|
||||
func selfDeleteExe() error { return nil }
|
||||
138
cmd/cli/self_delete_windows.go
Normal file
138
cmd/cli/self_delete_windows.go
Normal file
@@ -0,0 +1,138 @@
|
||||
// Copied from https://github.com/secur30nly/go-self-delete
|
||||
// with modification to suitable for ctrld usage.
|
||||
|
||||
/*
|
||||
License: MIT Licence
|
||||
|
||||
References:
|
||||
- https://github.com/LloydLabs/delete-self-poc
|
||||
- https://twitter.com/jonasLyk/status/1350401461985955840
|
||||
*/
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var supportedSelfDelete = false
|
||||
|
||||
type FILE_RENAME_INFO struct {
|
||||
Union struct {
|
||||
ReplaceIfExists bool
|
||||
Flags uint32
|
||||
}
|
||||
RootDirectory windows.Handle
|
||||
FileNameLength uint32
|
||||
FileName [1]uint16
|
||||
}
|
||||
|
||||
type FILE_DISPOSITION_INFO struct {
|
||||
DeleteFile bool
|
||||
}
|
||||
|
||||
// dsOpenHandle opens a handle to the specified file with DELETE access
|
||||
func dsOpenHandle(pwPath *uint16) (windows.Handle, error) {
|
||||
handle, err := windows.CreateFile(
|
||||
pwPath,
|
||||
windows.DELETE,
|
||||
0,
|
||||
nil,
|
||||
windows.OPEN_EXISTING,
|
||||
windows.FILE_ATTRIBUTE_NORMAL,
|
||||
0,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
// dsRenameHandle renames a file handle to a stream name
|
||||
func dsRenameHandle(hHandle windows.Handle) error {
|
||||
var fRename FILE_RENAME_INFO
|
||||
DS_STREAM_RENAME, err := windows.UTF16FromString(":deadbeef")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lpwStream := &DS_STREAM_RENAME[0]
|
||||
fRename.FileNameLength = uint32(unsafe.Sizeof(lpwStream))
|
||||
|
||||
windows.NewLazyDLL("kernel32.dll").NewProc("RtlCopyMemory").Call(
|
||||
uintptr(unsafe.Pointer(&fRename.FileName[0])),
|
||||
uintptr(unsafe.Pointer(lpwStream)),
|
||||
unsafe.Sizeof(lpwStream),
|
||||
)
|
||||
|
||||
err = windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileRenameInfo,
|
||||
(*byte)(unsafe.Pointer(&fRename)),
|
||||
uint32(unsafe.Sizeof(fRename)+unsafe.Sizeof(lpwStream)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dsDepositeHandle marks a file handle for deletion
|
||||
func dsDepositeHandle(hHandle windows.Handle) error {
|
||||
var fDelete FILE_DISPOSITION_INFO
|
||||
fDelete.DeleteFile = true
|
||||
|
||||
err := windows.SetFileInformationByHandle(
|
||||
hHandle,
|
||||
windows.FileDispositionInfo,
|
||||
(*byte)(unsafe.Pointer(&fDelete)),
|
||||
uint32(unsafe.Sizeof(fDelete)),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// selfDeleteExe performs self-deletion on Windows platforms
|
||||
func selfDeleteExe() error {
|
||||
var wcPath [windows.MAX_PATH + 1]uint16
|
||||
var hCurrent windows.Handle
|
||||
|
||||
_, err := windows.GetModuleFileName(0, &wcPath[0], windows.MAX_PATH)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsRenameHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
|
||||
hCurrent, err = dsOpenHandle(&wcPath[0])
|
||||
if err != nil || hCurrent == windows.InvalidHandle {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := dsDepositeHandle(hCurrent); err != nil {
|
||||
_ = windows.CloseHandle(hCurrent)
|
||||
return err
|
||||
}
|
||||
|
||||
return windows.CloseHandle(hCurrent)
|
||||
}
|
||||
17
cmd/cli/self_kill_others.go
Normal file
17
cmd/cli/self_kill_others.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build !unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// selfUninstall performs self-uninstallation on non-Unix platforms
|
||||
func selfUninstall(p *prog, logger *ctrld.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, false) {
|
||||
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
47
cmd/cli/self_kill_unix.go
Normal file
47
cmd/cli/self_kill_unix.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//go:build unix
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
// selfUninstall performs self-uninstallation on Unix platforms
|
||||
func selfUninstall(p *prog, logger *ctrld.Logger) {
|
||||
if runtime.GOOS == "linux" {
|
||||
selfUninstallLinux(p, logger)
|
||||
}
|
||||
|
||||
bin, err := os.Executable()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("Could not determine executable")
|
||||
}
|
||||
args := []string{"uninstall"}
|
||||
if deactivationPinSet() {
|
||||
args = append(args, fmt.Sprintf("--pin=%d", cdDeactivationPin.Load()))
|
||||
}
|
||||
cmd := exec.Command(bin, args...)
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
if err := cmd.Start(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("Could not start self uninstall command")
|
||||
}
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
|
||||
_ = cmd.Wait()
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// selfUninstallLinux performs self-uninstallation on Linux platforms
|
||||
func selfUninstallLinux(p *prog, logger *ctrld.Logger) {
|
||||
if uninstallInvalidCdUID(p, logger, true) {
|
||||
logger.Warn().Msgf("Service was uninstalled because device %q does not exist", cdUID)
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
12
cmd/cli/self_upgrade_others.go
Normal file
12
cmd/cli/self_upgrade_others.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running a detached child command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{Setsid: true}
|
||||
}
|
||||
18
cmd/cli/self_upgrade_windows.go
Normal file
18
cmd/cli/self_upgrade_windows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// From: https://learn.microsoft.com/en-us/windows/win32/procthread/process-creation-flags?redirectedfrom=MSDN
|
||||
|
||||
// SYSCALL_CREATE_NO_WINDOW set flag to run process without a console window.
|
||||
const SYSCALL_CREATE_NO_WINDOW = 0x08000000
|
||||
|
||||
// sysProcAttrForDetachedChildProcess returns *syscall.SysProcAttr instance for running self-upgrade command.
|
||||
func sysProcAttrForDetachedChildProcess() *syscall.SysProcAttr {
|
||||
return &syscall.SysProcAttr{
|
||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | SYSCALL_CREATE_NO_WINDOW,
|
||||
HideWindow: true,
|
||||
}
|
||||
}
|
||||
31
cmd/cli/sema.go
Normal file
31
cmd/cli/sema.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package cli
|
||||
|
||||
// semaphore provides a simple synchronization mechanism
|
||||
type semaphore interface {
|
||||
acquire()
|
||||
release()
|
||||
}
|
||||
|
||||
// noopSemaphore is a no-operation implementation of semaphore
|
||||
type noopSemaphore struct{}
|
||||
|
||||
// acquire performs a no-operation for the noop semaphore
|
||||
func (n noopSemaphore) acquire() {}
|
||||
|
||||
// release performs a no-operation for the noop semaphore
|
||||
func (n noopSemaphore) release() {}
|
||||
|
||||
// chanSemaphore is a channel-based implementation of semaphore
|
||||
type chanSemaphore struct {
|
||||
ready chan struct{}
|
||||
}
|
||||
|
||||
// acquire blocks until a slot is available in the semaphore
|
||||
func (c *chanSemaphore) acquire() {
|
||||
c.ready <- struct{}{}
|
||||
}
|
||||
|
||||
// release signals that a slot has been freed in the semaphore
|
||||
func (c *chanSemaphore) release() {
|
||||
<-c.ready
|
||||
}
|
||||
227
cmd/cli/service.go
Normal file
227
cmd/cli/service.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/unit"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
// newService wraps service.New call to return service.Service
|
||||
// wrapper which is suitable for the current platform.
|
||||
func newService(i service.Interface, c *service.Config) (service.Service, error) {
|
||||
s, err := service.New(i, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch {
|
||||
case s.Platform() == "unix-systemv":
|
||||
return &sysV{s}, nil
|
||||
case s.Platform() == "linux-systemd":
|
||||
return &systemd{s}, nil
|
||||
case s.Platform() == "darwin-launchd":
|
||||
return newLaunchd(s), nil
|
||||
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// sysV wraps a service.Service, and provide start/stop/status command
|
||||
// base on "/etc/init.d/<service_name>".
|
||||
//
|
||||
// Use this on system where "service" command is not available.
|
||||
type sysV struct {
|
||||
service.Service
|
||||
}
|
||||
|
||||
func (s *sysV) installed() bool {
|
||||
fi, err := os.Stat("/etc/init.d/ctrld")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
mode := fi.Mode()
|
||||
return mode.IsRegular() && (mode&0111) != 0
|
||||
}
|
||||
|
||||
func (s *sysV) Start() error {
|
||||
if !s.installed() {
|
||||
return service.ErrNotInstalled
|
||||
}
|
||||
_, err := exec.Command("/etc/init.d/ctrld", "start").CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *sysV) Stop() error {
|
||||
if !s.installed() {
|
||||
return service.ErrNotInstalled
|
||||
}
|
||||
_, err := exec.Command("/etc/init.d/ctrld", "stop").CombinedOutput()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *sysV) Restart() error {
|
||||
if !s.installed() {
|
||||
return service.ErrNotInstalled
|
||||
}
|
||||
// We don't care about error returned by s.Stop,
|
||||
// because the service may already be stopped.
|
||||
_ = s.Stop()
|
||||
return s.Start()
|
||||
}
|
||||
|
||||
func (s *sysV) Status() (service.Status, error) {
|
||||
if !s.installed() {
|
||||
return service.StatusUnknown, service.ErrNotInstalled
|
||||
}
|
||||
return unixSystemVServiceStatus()
|
||||
}
|
||||
|
||||
// systemd wraps a service.Service, and provide status command to
|
||||
// report the status correctly.
|
||||
type systemd struct {
|
||||
service.Service
|
||||
}
|
||||
|
||||
func (s *systemd) Status() (service.Status, error) {
|
||||
out, _ := exec.Command("systemctl", "status", "ctrld").CombinedOutput()
|
||||
if bytes.Contains(out, []byte("/FAILURE)")) {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
return s.Service.Status()
|
||||
}
|
||||
|
||||
func (s *systemd) Start() error {
|
||||
const systemdUnitFile = "/etc/systemd/system/ctrld.service"
|
||||
f, err := os.Open(systemdUnitFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if opts, change := ensureSystemdKillMode(f); change {
|
||||
mode := os.FileMode(0644)
|
||||
buf, err := io.ReadAll(unit.Serialize(opts))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(systemdUnitFile, buf, mode); err != nil {
|
||||
return err
|
||||
}
|
||||
if out, err := exec.Command("systemctl", "daemon-reload").CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("systemctl daemon-reload failed: %w\n%s", err, string(out))
|
||||
}
|
||||
mainLog.Load().Debug().Msg("Set KillMode=process successfully")
|
||||
}
|
||||
return s.Service.Start()
|
||||
}
|
||||
|
||||
// ensureSystemdKillMode ensure systemd unit file is configured with KillMode=process.
|
||||
// This is necessary for running self-upgrade flow.
|
||||
func ensureSystemdKillMode(r io.Reader) (opts []*unit.UnitOption, change bool) {
|
||||
opts, err := unit.DeserializeOptions(r)
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Err(err).Msg("Failed to deserialize options")
|
||||
return
|
||||
}
|
||||
change = true
|
||||
needKillModeOpt := true
|
||||
killModeOpt := unit.NewUnitOption("Service", "KillMode", "process")
|
||||
for _, opt := range opts {
|
||||
if opt.Match(killModeOpt) {
|
||||
needKillModeOpt = false
|
||||
change = false
|
||||
break
|
||||
}
|
||||
if opt.Section == killModeOpt.Section && opt.Name == killModeOpt.Name {
|
||||
opt.Value = killModeOpt.Value
|
||||
needKillModeOpt = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if needKillModeOpt {
|
||||
opts = append(opts, killModeOpt)
|
||||
}
|
||||
return opts, change
|
||||
}
|
||||
|
||||
// newLaunchd creates a new launchd service wrapper
|
||||
func newLaunchd(s service.Service) *launchd {
|
||||
return &launchd{
|
||||
Service: s,
|
||||
statusErrMsg: "Permission denied",
|
||||
}
|
||||
}
|
||||
|
||||
// launchd wraps a service.Service, and provide status command to
|
||||
// report the status correctly when not running as root on Darwin.
|
||||
//
|
||||
// TODO: remove this wrapper once https://github.com/kardianos/service/issues/400 fixed.
|
||||
type launchd struct {
|
||||
service.Service
|
||||
statusErrMsg string
|
||||
}
|
||||
|
||||
func (l *launchd) Status() (service.Status, error) {
|
||||
if os.Geteuid() != 0 {
|
||||
return service.StatusUnknown, errors.New(l.statusErrMsg)
|
||||
}
|
||||
return l.Service.Status()
|
||||
}
|
||||
|
||||
type task struct {
|
||||
f func() error
|
||||
abortOnError bool
|
||||
Name string
|
||||
}
|
||||
|
||||
// doTasks executes a list of tasks and returns success status
|
||||
func doTasks(tasks []task) bool {
|
||||
for _, task := range tasks {
|
||||
mainLog.Load().Debug().Msgf("Running task %s", task.Name)
|
||||
if err := task.f(); err != nil {
|
||||
if task.abortOnError {
|
||||
mainLog.Load().Error().Msgf("Error running task %s: %v", task.Name, err)
|
||||
return false
|
||||
}
|
||||
// if this is darwin stop command, dont print debug
|
||||
// since launchctl complains on every start
|
||||
if runtime.GOOS != "darwin" || task.Name != "Stop" {
|
||||
mainLog.Load().Debug().Msgf("Error running task %s: %v", task.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// checkHasElevatedPrivilege checks if the process has elevated privileges and exits if not
|
||||
func checkHasElevatedPrivilege() {
|
||||
ok, err := hasElevatedPrivilege()
|
||||
if err != nil {
|
||||
mainLog.Load().Error().Msgf("Could not detect user privilege: %v", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
mainLog.Load().Error().Msg("Please relaunch process with admin/root privilege.")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// unixSystemVServiceStatus checks the status of a Unix System V service
|
||||
func unixSystemVServiceStatus() (service.Status, error) {
|
||||
out, err := exec.Command("/etc/init.d/ctrld", "status").CombinedOutput()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, nil
|
||||
}
|
||||
|
||||
switch string(bytes.ToLower(bytes.TrimSpace(out))) {
|
||||
case "running":
|
||||
return service.StatusRunning, nil
|
||||
default:
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
}
|
||||
20
cmd/cli/service_others.go
Normal file
20
cmd/cli/service_others.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build !windows
|
||||
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// hasElevatedPrivilege checks if the current process has elevated privileges
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
return os.Geteuid() == 0, nil
|
||||
}
|
||||
|
||||
// openLogFile opens a log file with the specified flags
|
||||
func openLogFile(path string, flags int) (*os.File, error) {
|
||||
return os.OpenFile(path, flags, os.FileMode(0o600))
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions is a no-op on non-Windows platforms
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error { return nil }
|
||||
28
cmd/cli/service_test.go
Normal file
28
cmd/cli/service_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_ensureSystemdKillMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
unitFile string
|
||||
wantChange bool
|
||||
}{
|
||||
{"no KillMode", "[Service]\nExecStart=/bin/sleep 1", true},
|
||||
{"not KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=mixed", true},
|
||||
{"KillMode=process", "[Service]\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
{"invalid unit file", "[Service\nExecStart=/bin/sleep 1\nKillMode=process", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if _, change := ensureSystemdKillMode(strings.NewReader(tc.unitFile)); tc.wantChange != change {
|
||||
t.Errorf("ensureSystemdKillMode(%q) = %v, want %v", tc.unitFile, change, tc.wantChange)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
148
cmd/cli/service_windows.go
Normal file
148
cmd/cli/service_windows.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
// hasElevatedPrivilege checks if the current process has elevated privileges on Windows
|
||||
func hasElevatedPrivilege() (bool, error) {
|
||||
var sid *windows.SID
|
||||
if err := windows.AllocateAndInitializeSid(
|
||||
&windows.SECURITY_NT_AUTHORITY,
|
||||
2,
|
||||
windows.SECURITY_BUILTIN_DOMAIN_RID,
|
||||
windows.DOMAIN_ALIAS_RID_ADMINS,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
&sid,
|
||||
); err != nil {
|
||||
return false, err
|
||||
}
|
||||
token := windows.Token(0)
|
||||
return token.IsMember(sid)
|
||||
}
|
||||
|
||||
// ConfigureWindowsServiceFailureActions checks if the given service
|
||||
// has the correct failure actions configured, and updates them if not.
|
||||
func ConfigureWindowsServiceFailureActions(serviceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil // no-op on non-Windows
|
||||
}
|
||||
|
||||
m, err := mgr.Connect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer m.Disconnect()
|
||||
|
||||
s, err := m.OpenService(serviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
// 1. Retrieve the current config
|
||||
cfg, err := s.Config()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. Update the Description
|
||||
cfg.Description = "A highly configurable, multi-protocol DNS forwarding proxy"
|
||||
|
||||
// 3. Apply the updated config
|
||||
if err := s.UpdateConfig(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Then proceed with existing actions, e.g. setting failure actions
|
||||
actions := []mgr.RecoveryAction{
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
{Type: mgr.ServiceRestart, Delay: time.Second * 5}, // 5 seconds
|
||||
}
|
||||
|
||||
// Set the recovery actions (3 restarts, reset period = 120).
|
||||
err = s.SetRecoveryActions(actions, 120)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure that failure actions are NOT triggered on user-initiated stops.
|
||||
var failureActionsFlag windows.SERVICE_FAILURE_ACTIONS_FLAG
|
||||
failureActionsFlag.FailureActionsOnNonCrashFailures = 0
|
||||
|
||||
if err := windows.ChangeServiceConfig2(
|
||||
s.Handle,
|
||||
windows.SERVICE_CONFIG_FAILURE_ACTIONS_FLAG,
|
||||
(*byte)(unsafe.Pointer(&failureActionsFlag)),
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// openLogFile opens a log file with the specified mode on Windows
|
||||
func openLogFile(path string, mode int) (*os.File, error) {
|
||||
if len(path) == 0 {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND}
|
||||
}
|
||||
|
||||
pathP, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var access uint32
|
||||
switch mode & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) {
|
||||
case os.O_RDONLY:
|
||||
access = windows.GENERIC_READ
|
||||
case os.O_WRONLY:
|
||||
access = windows.GENERIC_WRITE
|
||||
case os.O_RDWR:
|
||||
access = windows.GENERIC_READ | windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_CREATE != 0 {
|
||||
access |= windows.GENERIC_WRITE
|
||||
}
|
||||
if mode&os.O_APPEND != 0 {
|
||||
access &^= windows.GENERIC_WRITE
|
||||
access |= windows.FILE_APPEND_DATA
|
||||
}
|
||||
|
||||
shareMode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE)
|
||||
|
||||
var sa *syscall.SecurityAttributes
|
||||
|
||||
var createMode uint32
|
||||
switch {
|
||||
case mode&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL):
|
||||
createMode = windows.CREATE_NEW
|
||||
case mode&(os.O_CREATE|os.O_TRUNC) == (os.O_CREATE | os.O_TRUNC):
|
||||
createMode = windows.CREATE_ALWAYS
|
||||
case mode&os.O_CREATE == os.O_CREATE:
|
||||
createMode = windows.OPEN_ALWAYS
|
||||
case mode&os.O_TRUNC == os.O_TRUNC:
|
||||
createMode = windows.TRUNCATE_EXISTING
|
||||
default:
|
||||
createMode = windows.OPEN_EXISTING
|
||||
}
|
||||
|
||||
handle, err := syscall.CreateFile(pathP, access, shareMode, sa, createMode, syscall.FILE_ATTRIBUTE_NORMAL, 0)
|
||||
if err != nil {
|
||||
return nil, &os.PathError{Path: path, Op: "open", Err: err}
|
||||
}
|
||||
|
||||
return os.NewFile(uintptr(handle), path), nil
|
||||
}
|
||||
130
cmd/cli/upstream_monitor.go
Normal file
130
cmd/cli/upstream_monitor.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxFailureRequest is the maximum failed queries allowed before an upstream is marked as down.
|
||||
maxFailureRequest = 50
|
||||
// checkUpstreamBackoffSleep is the time interval between each upstream checks.
|
||||
checkUpstreamBackoffSleep = 2 * time.Second
|
||||
)
|
||||
|
||||
// upstreamMonitor performs monitoring upstreams health.
|
||||
type upstreamMonitor struct {
|
||||
cfg *ctrld.Config
|
||||
logger atomic.Pointer[ctrld.Logger]
|
||||
|
||||
mu sync.RWMutex
|
||||
checking map[string]bool
|
||||
down map[string]bool
|
||||
failureReq map[string]uint64
|
||||
recovered map[string]bool
|
||||
|
||||
// failureTimerActive tracks if a timer is already running for a given upstream.
|
||||
failureTimerActive map[string]bool
|
||||
}
|
||||
|
||||
// newUpstreamMonitor creates a new upstream monitor instance
|
||||
func newUpstreamMonitor(cfg *ctrld.Config, logger *ctrld.Logger) *upstreamMonitor {
|
||||
um := &upstreamMonitor{
|
||||
cfg: cfg,
|
||||
checking: make(map[string]bool),
|
||||
down: make(map[string]bool),
|
||||
failureReq: make(map[string]uint64),
|
||||
recovered: make(map[string]bool),
|
||||
failureTimerActive: make(map[string]bool),
|
||||
}
|
||||
um.logger.Store(logger)
|
||||
for n := range cfg.Upstream {
|
||||
upstream := upstreamPrefix + n
|
||||
um.reset(upstream)
|
||||
}
|
||||
um.reset(upstreamOS)
|
||||
return um
|
||||
}
|
||||
|
||||
// increaseFailureCount increases failed queries count for an upstream by 1 and logs debug information.
|
||||
// It uses a timer to debounce failure detection, ensuring that an upstream is marked as down
|
||||
// within 10 seconds if failures persist, without spawning duplicate goroutines.
|
||||
func (um *upstreamMonitor) increaseFailureCount(upstream string) {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
if um.recovered[upstream] {
|
||||
um.logger.Load().Debug().Msgf("Upstream %q is recovered, skipping failure count increase", upstream)
|
||||
return
|
||||
}
|
||||
|
||||
um.failureReq[upstream] += 1
|
||||
failedCount := um.failureReq[upstream]
|
||||
|
||||
// Log the updated failure count.
|
||||
um.logger.Load().Debug().Msgf("Upstream %q failure count updated to %d", upstream, failedCount)
|
||||
|
||||
// If this is the first failure and no timer is running, start a 10-second timer.
|
||||
if failedCount == 1 && !um.failureTimerActive[upstream] {
|
||||
um.failureTimerActive[upstream] = true
|
||||
go func(upstream string) {
|
||||
time.Sleep(10 * time.Second)
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
// If no success occurred during the 10-second window (i.e. counter remains > 0)
|
||||
// and the upstream is not in a recovered state, mark it as down.
|
||||
if um.failureReq[upstream] > 0 && !um.recovered[upstream] {
|
||||
um.down[upstream] = true
|
||||
um.logger.Load().Warn().Msgf("Upstream %q marked as down after 10 seconds (failure count: %d)", upstream, um.failureReq[upstream])
|
||||
}
|
||||
// Reset the timer flag so that a new timer can be spawned if needed.
|
||||
um.failureTimerActive[upstream] = false
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
// If the failure count quickly reaches the threshold, mark the upstream as down immediately.
|
||||
if failedCount >= maxFailureRequest {
|
||||
um.down[upstream] = true
|
||||
um.logger.Load().Warn().Msgf("Upstream %q marked as down immediately (failure count: %d)", upstream, failedCount)
|
||||
}
|
||||
}
|
||||
|
||||
// isDown reports whether the given upstream is being marked as down.
|
||||
func (um *upstreamMonitor) isDown(upstream string) bool {
|
||||
um.mu.Lock()
|
||||
defer um.mu.Unlock()
|
||||
|
||||
return um.down[upstream]
|
||||
}
|
||||
|
||||
// reset marks an upstream as up and set failed queries counter to zero.
|
||||
func (um *upstreamMonitor) reset(upstream string) {
|
||||
um.mu.Lock()
|
||||
um.failureReq[upstream] = 0
|
||||
um.down[upstream] = false
|
||||
um.recovered[upstream] = true
|
||||
um.mu.Unlock()
|
||||
go func() {
|
||||
// debounce the recovery to avoid incrementing failure counts already in flight
|
||||
time.Sleep(1 * time.Second)
|
||||
um.mu.Lock()
|
||||
um.recovered[upstream] = false
|
||||
um.mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// countHealthy returns the number of upstreams in the provided map that are considered healthy.
|
||||
func (um *upstreamMonitor) countHealthy(upstreams []string) int {
|
||||
var count int
|
||||
um.mu.RLock()
|
||||
for _, upstream := range upstreams {
|
||||
if !um.down[upstream] {
|
||||
count++
|
||||
}
|
||||
}
|
||||
um.mu.RUnlock()
|
||||
return count
|
||||
}
|
||||
20
cmd/cli/winres/winres.json
Normal file
20
cmd/cli/winres/winres.json
Normal file
@@ -0,0 +1,20 @@
|
||||
{
|
||||
"RT_VERSION": {
|
||||
"#1": {
|
||||
"0000": {
|
||||
"fixed": {
|
||||
"file_version": "0.0.0.1"
|
||||
},
|
||||
"info": {
|
||||
"0409": {
|
||||
"CompanyName": "ControlD Inc",
|
||||
"FileDescription": "Control D DNS daemon",
|
||||
"ProductName": "ctrld",
|
||||
"InternalName": "ctrld",
|
||||
"LegalCopyright": "ControlD Inc 2024"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
4
cmd/cli/winres_windows.go
Normal file
4
cmd/cli/winres_windows.go
Normal file
@@ -0,0 +1,4 @@
|
||||
//go:generate go-winres make --product-version=git-tag --file-version=git-tag
|
||||
package cli
|
||||
|
||||
// Placeholder file for windows builds.
|
||||
959
cmd/ctrld/cli.go
959
cmd/ctrld/cli.go
@@ -1,959 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cuonglm/osinfo"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/kardianos/service"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"tailscale.com/logtail/backoff"
|
||||
"tailscale.com/net/interfaces"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/certs"
|
||||
"github.com/Control-D-Inc/ctrld/internal/controld"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
var (
|
||||
version = "dev"
|
||||
commit = "none"
|
||||
)
|
||||
|
||||
var (
|
||||
v = viper.NewWithOptions(viper.KeyDelimiter("::"))
|
||||
defaultConfigWritten = false
|
||||
defaultConfigFile = "ctrld.toml"
|
||||
rootCertPool *x509.CertPool
|
||||
)
|
||||
|
||||
var basicModeFlags = []string{"listen", "primary_upstream", "secondary_upstream", "domains"}
|
||||
|
||||
func isNoConfigStart(cmd *cobra.Command) bool {
|
||||
for _, flagName := range basicModeFlags {
|
||||
if cmd.Flags().Lookup(flagName).Changed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const rootShortDesc = `
|
||||
__ .__ .___
|
||||
_____/ |________| | __| _/
|
||||
_/ ___\ __\_ __ \ | / __ |
|
||||
\ \___| | | | \/ |__/ /_/ |
|
||||
\___ >__| |__| |____/\____ |
|
||||
\/ dns forwarding proxy \/
|
||||
`
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "ctrld",
|
||||
Short: strings.TrimLeft(rootShortDesc, "\n"),
|
||||
Version: curVersion(),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
}
|
||||
|
||||
func curVersion() string {
|
||||
if version != "dev" && !strings.HasPrefix(version, "v") {
|
||||
version = "v" + version
|
||||
}
|
||||
if len(commit) > 7 {
|
||||
commit = commit[:7]
|
||||
}
|
||||
return fmt.Sprintf("%s-%s", version, commit)
|
||||
}
|
||||
|
||||
func initCLI() {
|
||||
// Enable opening via explorer.exe on Windows.
|
||||
// See: https://github.com/spf13/cobra/issues/844.
|
||||
cobra.MousetrapHelpText = ""
|
||||
cobra.EnableCommandSorting = false
|
||||
|
||||
rootCmd.PersistentFlags().CountVarP(
|
||||
&verbose,
|
||||
"verbose",
|
||||
"v",
|
||||
`verbose log output, "-v" basic logging, "-vv" debug level logging`,
|
||||
)
|
||||
rootCmd.PersistentFlags().BoolVarP(
|
||||
&silent,
|
||||
"silent",
|
||||
"s",
|
||||
false,
|
||||
`do not write any log output`,
|
||||
)
|
||||
rootCmd.SetHelpCommand(&cobra.Command{Hidden: true})
|
||||
rootCmd.CompletionOptions.HiddenDefaultCmd = true
|
||||
|
||||
runCmd := &cobra.Command{
|
||||
Use: "run",
|
||||
Short: "Run the DNS proxy server",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if daemon && runtime.GOOS == "windows" {
|
||||
mainLog.Fatal().Msg("Cannot run in daemon mode. Please install a Windows service.")
|
||||
}
|
||||
|
||||
waitCh := make(chan struct{})
|
||||
stopCh := make(chan struct{})
|
||||
if !daemon {
|
||||
// We need to call s.Run() as soon as possible to response to the OS manager, so it
|
||||
// can see ctrld is running and don't mark ctrld as failed service.
|
||||
go func() {
|
||||
p := &prog{
|
||||
waitCh: waitCh,
|
||||
stopCh: stopCh,
|
||||
}
|
||||
s, err := service.New(p, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed create new service")
|
||||
}
|
||||
s = newService(s)
|
||||
if err := s.Run(); err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to start service")
|
||||
}
|
||||
}()
|
||||
}
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
readBase64Config(configBase64)
|
||||
processNoConfigFlags(noConfigStart)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
mainLog.Info().Msgf("starting ctrld %s", curVersion())
|
||||
oi := osinfo.New()
|
||||
mainLog.Info().Msgf("os: %s", oi.String())
|
||||
|
||||
// Wait for network up.
|
||||
if !ctrldnet.Up() {
|
||||
mainLog.Fatal().Msg("network is not up yet")
|
||||
}
|
||||
processLogAndCacheFlags()
|
||||
// Log config do not have thing to validate, so it's safe to init log here,
|
||||
// so it's able to log information in processCDFlags.
|
||||
initLogging()
|
||||
|
||||
if setupRouter {
|
||||
s, errCh := runDNSServerForNTPD(router.ListenAddress())
|
||||
if err := router.PreRun(); err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to perform router pre-start check")
|
||||
}
|
||||
if err := s.Shutdown(); err != nil && errCh != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to shutdown dns server for ntpd")
|
||||
}
|
||||
}
|
||||
|
||||
processCDFlags()
|
||||
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("invalid config: %v", err)
|
||||
}
|
||||
initCache()
|
||||
|
||||
if daemon {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to find the binary")
|
||||
os.Exit(1)
|
||||
}
|
||||
curDir, err := os.Getwd()
|
||||
if err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to get current working directory")
|
||||
os.Exit(1)
|
||||
}
|
||||
// If running as daemon, re-run the command in background, with daemon off.
|
||||
cmd := exec.Command(exe, append(os.Args[1:], "-d=false")...)
|
||||
cmd.Dir = curDir
|
||||
if err := cmd.Start(); err != nil {
|
||||
mainLog.Error().Err(err).Msg("failed to start process as daemon")
|
||||
os.Exit(1)
|
||||
}
|
||||
mainLog.Info().Int("pid", cmd.Process.Pid).Msg("DNS proxy started")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
if setupRouter {
|
||||
switch platform := router.Name(); {
|
||||
case platform == router.DDWrt:
|
||||
rootCertPool = certs.CACertPool()
|
||||
fallthrough
|
||||
case platform != "":
|
||||
mainLog.Debug().Msg("Router setup")
|
||||
err := router.Configure(&cfg)
|
||||
if errors.Is(err, router.ErrNotSupported) {
|
||||
unsupportedPlatformHelp(cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to configure router")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
close(waitCh)
|
||||
<-stopCh
|
||||
},
|
||||
}
|
||||
runCmd.Flags().BoolVarP(&daemon, "daemon", "d", false, "Run as daemon")
|
||||
runCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
runCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
runCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
runCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
runCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
runCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
runCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
runCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
runCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
runCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = runCmd.Flags().MarkHidden("dev")
|
||||
runCmd.Flags().StringVarP(&homedir, "homedir", "", "", "")
|
||||
_ = runCmd.Flags().MarkHidden("homedir")
|
||||
runCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
_ = runCmd.Flags().MarkHidden("iface")
|
||||
runCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
|
||||
_ = runCmd.Flags().MarkHidden("router")
|
||||
|
||||
rootCmd.AddCommand(runCmd)
|
||||
|
||||
startCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Install and start the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
sc := &service.Config{}
|
||||
*sc = *svcConfig
|
||||
osArgs := os.Args[2:]
|
||||
if os.Args[1] == "service" {
|
||||
osArgs = os.Args[3:]
|
||||
}
|
||||
setDependencies(sc)
|
||||
sc.Arguments = append([]string{"run"}, osArgs...)
|
||||
if err := router.ConfigureService(sc); err != nil {
|
||||
mainLog.Fatal().Err(err).Msg("failed to configure service on router")
|
||||
}
|
||||
|
||||
// No config path, generating config in HOME directory.
|
||||
noConfigStart := isNoConfigStart(cmd)
|
||||
writeDefaultConfig := !noConfigStart && configBase64 == ""
|
||||
if configPath != "" {
|
||||
v.SetConfigFile(configPath)
|
||||
}
|
||||
if dir, err := userHomeDir(); err == nil {
|
||||
setWorkingDirectory(sc, dir)
|
||||
if configPath == "" && writeDefaultConfig {
|
||||
defaultConfigFile = filepath.Join(dir, defaultConfigFile)
|
||||
}
|
||||
sc.Arguments = append(sc.Arguments, "--homedir="+dir)
|
||||
}
|
||||
|
||||
tryReadingConfig(writeDefaultConfig)
|
||||
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
|
||||
logPath := cfg.Service.LogPath
|
||||
cfg.Service.LogPath = ""
|
||||
initLogging()
|
||||
cfg.Service.LogPath = logPath
|
||||
|
||||
processCDFlags()
|
||||
|
||||
if err := ctrld.ValidateConfig(validator.New(), &cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("invalid config: %v", err)
|
||||
}
|
||||
|
||||
// Explicitly passing config, so on system where home directory could not be obtained,
|
||||
// or sub-process env is different with the parent, we still behave correctly and use
|
||||
// the expected config file.
|
||||
if configPath == "" {
|
||||
sc.Arguments = append(sc.Arguments, "--config="+defaultConfigFile)
|
||||
}
|
||||
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, sc)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
tasks := []task{
|
||||
{s.Stop, false},
|
||||
{s.Uninstall, false},
|
||||
{s.Install, false},
|
||||
{s.Start, true},
|
||||
}
|
||||
if doTasks(tasks) {
|
||||
if err := router.PostInstall(svcConfig); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("post installation failed, please check system/service log for details error")
|
||||
return
|
||||
}
|
||||
status, err := serviceStatus(s)
|
||||
if err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not get service status")
|
||||
return
|
||||
}
|
||||
|
||||
domain := cfg.Upstream["0"].VerifyDomain()
|
||||
status = selfCheckStatus(status, domain)
|
||||
switch status {
|
||||
case service.StatusRunning:
|
||||
mainLog.Notice().Msg("Service started")
|
||||
default:
|
||||
mainLog.Error().Msg("Service did not start, please check system/service log for details error")
|
||||
if runtime.GOOS == "linux" {
|
||||
prog.resetDNS()
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
prog.setDNS()
|
||||
}
|
||||
},
|
||||
}
|
||||
// Keep these flags in sync with runCmd above, except for "-d".
|
||||
startCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
startCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
startCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
startCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
startCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
startCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
startCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
startCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
startCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
startCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = startCmd.Flags().MarkHidden("dev")
|
||||
startCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmd.Flags().BoolVarP(&setupRouter, "router", "", false, `setup for running on router platforms`)
|
||||
_ = startCmd.Flags().MarkHidden("router")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "stop",
|
||||
Short: "Stop the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
initLogging()
|
||||
if doTasks([]task{{s.Stop, true}}) {
|
||||
prog.resetDNS()
|
||||
mainLog.Notice().Msg("Service stopped")
|
||||
}
|
||||
},
|
||||
}
|
||||
stopCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
|
||||
restartCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "restart",
|
||||
Short: "Restart the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
initLogging()
|
||||
if doTasks([]task{{s.Restart, true}}) {
|
||||
mainLog.Notice().Msg("Service restarted")
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show status of the ctrld service",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
s = newService(s)
|
||||
status, err := serviceStatus(s)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
switch status {
|
||||
case service.StatusUnknown:
|
||||
mainLog.Notice().Msg("Unknown status")
|
||||
os.Exit(2)
|
||||
case service.StatusRunning:
|
||||
mainLog.Notice().Msg("Service is running")
|
||||
os.Exit(0)
|
||||
case service.StatusStopped:
|
||||
mainLog.Notice().Msg("Service is stopped")
|
||||
os.Exit(1)
|
||||
}
|
||||
},
|
||||
}
|
||||
if runtime.GOOS == "darwin" {
|
||||
// On darwin, running status command without privileges may return wrong information.
|
||||
statusCmd.PreRun = func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
}
|
||||
}
|
||||
|
||||
uninstallCmd := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "uninstall",
|
||||
Short: "Stop and uninstall the ctrld service",
|
||||
Long: `Stop and uninstall the ctrld service.
|
||||
|
||||
NOTE: Uninstalling will set DNS to values provided by DHCP.`,
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
prog := &prog{}
|
||||
s, err := service.New(prog, svcConfig)
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
return
|
||||
}
|
||||
tasks := []task{
|
||||
{s.Stop, false},
|
||||
{s.Uninstall, true},
|
||||
}
|
||||
initLogging()
|
||||
if doTasks(tasks) {
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
prog.resetDNS()
|
||||
mainLog.Debug().Msg("Router cleanup")
|
||||
if err := router.Cleanup(svcConfig); err != nil {
|
||||
mainLog.Warn().Err(err).Msg("could not cleanup router")
|
||||
}
|
||||
mainLog.Notice().Msg("Service uninstalled")
|
||||
return
|
||||
}
|
||||
},
|
||||
}
|
||||
uninstallCmd.Flags().StringVarP(&iface, "iface", "", "", `Reset DNS setting for iface, use "auto" for the default gateway interface`)
|
||||
|
||||
listIfacesCmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List network interfaces of the host",
|
||||
Args: cobra.NoArgs,
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
err := interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
fmt.Printf("Index : %d\n", i.Index)
|
||||
fmt.Printf("Name : %s\n", i.Name)
|
||||
addrs, _ := i.Addrs()
|
||||
for i, ipaddr := range addrs {
|
||||
if i == 0 {
|
||||
fmt.Printf("Addrs : %v\n", ipaddr)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" %v\n", ipaddr)
|
||||
}
|
||||
for i, dns := range currentDNS(i.Interface) {
|
||||
if i == 0 {
|
||||
fmt.Printf("DNS : %s\n", dns)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" : %s\n", dns)
|
||||
}
|
||||
println()
|
||||
})
|
||||
if err != nil {
|
||||
mainLog.Error().Msg(err.Error())
|
||||
}
|
||||
},
|
||||
}
|
||||
interfacesCmd := &cobra.Command{
|
||||
Use: "interfaces",
|
||||
Short: "Manage network interfaces",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
listIfacesCmd.Use,
|
||||
},
|
||||
}
|
||||
interfacesCmd.AddCommand(listIfacesCmd)
|
||||
|
||||
serviceCmd := &cobra.Command{
|
||||
Use: "service",
|
||||
Short: "Manage ctrld service",
|
||||
Args: cobra.OnlyValidArgs,
|
||||
ValidArgs: []string{
|
||||
statusCmd.Use,
|
||||
stopCmd.Use,
|
||||
restartCmd.Use,
|
||||
statusCmd.Use,
|
||||
uninstallCmd.Use,
|
||||
interfacesCmd.Use,
|
||||
},
|
||||
}
|
||||
serviceCmd.AddCommand(startCmd)
|
||||
serviceCmd.AddCommand(stopCmd)
|
||||
serviceCmd.AddCommand(restartCmd)
|
||||
serviceCmd.AddCommand(statusCmd)
|
||||
serviceCmd.AddCommand(uninstallCmd)
|
||||
serviceCmd.AddCommand(interfacesCmd)
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
startCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "start",
|
||||
Short: "Quick start service and configure DNS on interface",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
startCmd.Run(cmd, args)
|
||||
},
|
||||
}
|
||||
startCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
startCmdAlias.Flags().AddFlagSet(startCmd.Flags())
|
||||
rootCmd.AddCommand(startCmdAlias)
|
||||
stopCmdAlias := &cobra.Command{
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
checkHasElevatedPrivilege()
|
||||
},
|
||||
Use: "stop",
|
||||
Short: "Quick stop service and remove DNS from interface",
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if !cmd.Flags().Changed("iface") {
|
||||
os.Args = append(os.Args, "--iface="+ifaceStartStop)
|
||||
}
|
||||
iface = ifaceStartStop
|
||||
stopCmd.Run(cmd, args)
|
||||
},
|
||||
}
|
||||
stopCmdAlias.Flags().StringVarP(&ifaceStartStop, "iface", "", "auto", `Reset DNS setting for iface, "auto" means the default interface gateway`)
|
||||
stopCmdAlias.Flags().AddFlagSet(stopCmd.Flags())
|
||||
rootCmd.AddCommand(stopCmdAlias)
|
||||
}
|
||||
|
||||
func writeConfigFile() error {
|
||||
if cfu := v.ConfigFileUsed(); cfu != "" {
|
||||
defaultConfigFile = cfu
|
||||
} else if configPath != "" {
|
||||
defaultConfigFile = configPath
|
||||
}
|
||||
f, err := os.OpenFile(defaultConfigFile, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0o644))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if cdUID != "" {
|
||||
if _, err := f.WriteString("# AUTO-GENERATED VIA CD FLAG - DO NOT MODIFY\n\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
enc := toml.NewEncoder(f).SetIndentTables(true)
|
||||
if err := enc.Encode(&cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readConfigFile(writeDefaultConfig bool) bool {
|
||||
// If err == nil, there's a config supplied via `--config`, no default config written.
|
||||
err := v.ReadInConfig()
|
||||
if err == nil {
|
||||
mainLog.Info().Msg("loading config file from: " + v.ConfigFileUsed())
|
||||
defaultConfigFile = v.ConfigFileUsed()
|
||||
return true
|
||||
}
|
||||
|
||||
if !writeDefaultConfig {
|
||||
return false
|
||||
}
|
||||
|
||||
// If error is viper.ConfigFileNotFoundError, write default config.
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal default config: %v", err)
|
||||
}
|
||||
if err := writeConfigFile(); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to write default config file: %v", err)
|
||||
} else {
|
||||
fp, err := filepath.Abs(defaultConfigFile)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("failed to get default config file path: %v", err)
|
||||
}
|
||||
mainLog.Info().Msg("writing default config file to: " + fp)
|
||||
}
|
||||
defaultConfigWritten = true
|
||||
return false
|
||||
}
|
||||
// Otherwise, report fatal error and exit.
|
||||
mainLog.Fatal().Msgf("failed to decode config file: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
func readBase64Config(configBase64 string) {
|
||||
if configBase64 == "" {
|
||||
return
|
||||
}
|
||||
configStr, err := base64.StdEncoding.DecodeString(configBase64)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("invalid base64 config: %v", err)
|
||||
}
|
||||
if err := v.ReadConfig(bytes.NewReader(configStr)); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to read base64 config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func processNoConfigFlags(noConfigStart bool) {
|
||||
if !noConfigStart {
|
||||
return
|
||||
}
|
||||
if listenAddress == "" || primaryUpstream == "" {
|
||||
mainLog.Fatal().Msg(`"listen" and "primary_upstream" flags must be set in no config mode`)
|
||||
}
|
||||
processListenFlag()
|
||||
|
||||
endpointAndTyp := func(endpoint string) (string, string) {
|
||||
typ := ctrld.ResolverTypeFromEndpoint(endpoint)
|
||||
return strings.TrimPrefix(endpoint, "quic://"), typ
|
||||
}
|
||||
pEndpoint, pType := endpointAndTyp(primaryUpstream)
|
||||
upstream := map[string]*ctrld.UpstreamConfig{
|
||||
"0": {
|
||||
Name: pEndpoint,
|
||||
Endpoint: pEndpoint,
|
||||
Type: pType,
|
||||
Timeout: 5000,
|
||||
},
|
||||
}
|
||||
if secondaryUpstream != "" {
|
||||
sEndpoint, sType := endpointAndTyp(secondaryUpstream)
|
||||
upstream["1"] = &ctrld.UpstreamConfig{
|
||||
Name: sEndpoint,
|
||||
Endpoint: sEndpoint,
|
||||
Type: sType,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(domains))
|
||||
for _, domain := range domains {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{"upstream.1"}})
|
||||
}
|
||||
lc := v.Get("listener").(map[string]*ctrld.ListenerConfig)["0"]
|
||||
lc.Policy = &ctrld.ListenerPolicyConfig{Name: "My Policy", Rules: rules}
|
||||
}
|
||||
v.Set("upstream", upstream)
|
||||
}
|
||||
|
||||
func processCDFlags() {
|
||||
if cdUID == "" {
|
||||
return
|
||||
}
|
||||
if iface == "" {
|
||||
iface = "auto"
|
||||
}
|
||||
logger := mainLog.With().Str("mode", "cd").Logger()
|
||||
logger.Info().Msgf("fetching Controld D configuration from API: %s", cdUID)
|
||||
resolverConfig, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev)
|
||||
if uer, ok := err.(*controld.UtilityErrorResponse); ok && uer.ErrorField.Code == controld.InvalidConfigCode {
|
||||
s, err := service.New(&prog{}, svcConfig)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("failed to create new service")
|
||||
return
|
||||
}
|
||||
|
||||
if netIface, _ := netInterface(iface); netIface != nil {
|
||||
if err := restoreNetworkManager(); err != nil {
|
||||
logger.Error().Err(err).Msg("could not restore NetworkManager")
|
||||
return
|
||||
}
|
||||
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS for interface")
|
||||
if err := resetDNS(netIface); err != nil {
|
||||
logger.Warn().Err(err).Msg("something went wrong while restoring DNS")
|
||||
} else {
|
||||
logger.Debug().Str("iface", netIface.Name).Msg("Restoring DNS successfully")
|
||||
}
|
||||
}
|
||||
|
||||
tasks := []task{{s.Uninstall, true}}
|
||||
if doTasks(tasks) {
|
||||
logger.Info().Msg("uninstalled service")
|
||||
}
|
||||
logger.Fatal().Err(uer).Msg("failed to fetch resolver config")
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("could not fetch resolver config")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info().Msg("generating ctrld config from Control-D configuration")
|
||||
if resolverConfig.Ctrld.CustomConfig != "" {
|
||||
logger.Info().Msg("using defined custom config of Control-D resolver")
|
||||
readBase64Config(resolverConfig.Ctrld.CustomConfig)
|
||||
if err := v.Unmarshal(&cfg); err != nil {
|
||||
mainLog.Fatal().Msgf("failed to unmarshal config: %v", err)
|
||||
}
|
||||
for _, listener := range cfg.Listener {
|
||||
if listener.IP == "" {
|
||||
listener.IP = randomLocalIP()
|
||||
}
|
||||
if listener.Port == 0 {
|
||||
listener.Port = 53
|
||||
}
|
||||
}
|
||||
// On router, we want to keep the listener address point to dnsmasq listener, aka 127.0.0.1:53.
|
||||
if router.Name() != "" {
|
||||
if lc := cfg.Listener["0"]; lc != nil {
|
||||
lc.IP = "127.0.0.1"
|
||||
lc.Port = 53
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cfg = ctrld.Config{}
|
||||
cfg.Network = make(map[string]*ctrld.NetworkConfig)
|
||||
cfg.Network["0"] = &ctrld.NetworkConfig{
|
||||
Name: "Network 0",
|
||||
Cidrs: []string{"0.0.0.0/0"},
|
||||
}
|
||||
cfg.Upstream = make(map[string]*ctrld.UpstreamConfig)
|
||||
cfg.Upstream["0"] = &ctrld.UpstreamConfig{
|
||||
Endpoint: resolverConfig.DOH,
|
||||
Type: ctrld.ResolverTypeDOH,
|
||||
Timeout: 5000,
|
||||
}
|
||||
rules := make([]ctrld.Rule, 0, len(resolverConfig.Exclude))
|
||||
for _, domain := range resolverConfig.Exclude {
|
||||
rules = append(rules, ctrld.Rule{domain: []string{}})
|
||||
}
|
||||
cfg.Listener = make(map[string]*ctrld.ListenerConfig)
|
||||
cfg.Listener["0"] = &ctrld.ListenerConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 53,
|
||||
Policy: &ctrld.ListenerPolicyConfig{
|
||||
Name: "My Policy",
|
||||
Rules: rules,
|
||||
},
|
||||
}
|
||||
processLogAndCacheFlags()
|
||||
}
|
||||
|
||||
if err := writeConfigFile(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to write config file")
|
||||
} else {
|
||||
logger.Info().Msg("writing config file to: " + defaultConfigFile)
|
||||
}
|
||||
}
|
||||
|
||||
func processListenFlag() {
|
||||
if listenAddress == "" {
|
||||
return
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(listenAddress)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("invalid listener address: %v", err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("invalid port number: %v", err)
|
||||
}
|
||||
lc := &ctrld.ListenerConfig{
|
||||
IP: host,
|
||||
Port: port,
|
||||
}
|
||||
v.Set("listener", map[string]*ctrld.ListenerConfig{
|
||||
"0": lc,
|
||||
})
|
||||
}
|
||||
|
||||
func processLogAndCacheFlags() {
|
||||
if logPath != "" {
|
||||
cfg.Service.LogLevel = "debug"
|
||||
cfg.Service.LogPath = logPath
|
||||
}
|
||||
|
||||
if cacheSize != 0 {
|
||||
cfg.Service.CacheEnable = true
|
||||
cfg.Service.CacheSize = cacheSize
|
||||
}
|
||||
v.Set("service", cfg.Service)
|
||||
}
|
||||
|
||||
func netInterface(ifaceName string) (*net.Interface, error) {
|
||||
if ifaceName == "auto" {
|
||||
ifaceName = defaultIfaceName()
|
||||
}
|
||||
var iface *net.Interface
|
||||
err := interfaces.ForeachInterface(func(i interfaces.Interface, prefixes []netip.Prefix) {
|
||||
if i.Name == ifaceName {
|
||||
iface = i.Interface
|
||||
}
|
||||
})
|
||||
if iface == nil {
|
||||
return nil, errors.New("interface not found")
|
||||
}
|
||||
if err := patchNetIfaceName(iface); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return iface, err
|
||||
}
|
||||
|
||||
func defaultIfaceName() string {
|
||||
dri, err := interfaces.DefaultRouteInterface()
|
||||
if err != nil {
|
||||
// On WSL 1, the route table does not have any default route. But the fact that
|
||||
// it only uses /etc/resolv.conf for setup DNS, so we can use "lo" here.
|
||||
if oi := osinfo.New(); strings.Contains(oi.String(), "Microsoft") {
|
||||
return "lo"
|
||||
}
|
||||
mainLog.Fatal().Err(err).Msg("failed to get default route interface")
|
||||
}
|
||||
return dri
|
||||
}
|
||||
|
||||
func selfCheckStatus(status service.Status, domain string) service.Status {
|
||||
if domain == "" {
|
||||
// Nothing to do, return the status as-is.
|
||||
return status
|
||||
}
|
||||
c := new(dns.Client)
|
||||
bo := backoff.NewBackoff("self-check", logf, 10*time.Second)
|
||||
bo.LogLongerThan = 500 * time.Millisecond
|
||||
ctx := context.Background()
|
||||
maxAttempts := 20
|
||||
mainLog.Debug().Msg("Performing self-check")
|
||||
var (
|
||||
lcChanged map[string]*ctrld.ListenerConfig
|
||||
mu sync.Mutex
|
||||
)
|
||||
v.OnConfigChange(func(in fsnotify.Event) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if err := v.UnmarshalKey("listener", &lcChanged); err != nil {
|
||||
mainLog.Error().Msgf("failed to unmarshal listener config: %v", err)
|
||||
return
|
||||
}
|
||||
})
|
||||
v.WatchConfig()
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
lc := cfg.Listener["0"]
|
||||
mu.Lock()
|
||||
if lcChanged != nil {
|
||||
lc = lcChanged["0"]
|
||||
}
|
||||
mu.Unlock()
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(domain+".", dns.TypeA)
|
||||
m.RecursionDesired = true
|
||||
r, _, err := c.ExchangeContext(ctx, m, net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port)))
|
||||
if r != nil && r.Rcode == dns.RcodeSuccess && len(r.Answer) > 0 {
|
||||
mainLog.Debug().Msgf("self-check against %q succeeded", domain)
|
||||
return status
|
||||
}
|
||||
bo.BackOff(ctx, fmt.Errorf("ExchangeContext: %w", err))
|
||||
}
|
||||
mainLog.Debug().Msgf("self-check against %q failed", domain)
|
||||
return service.StatusUnknown
|
||||
}
|
||||
|
||||
func unsupportedPlatformHelp(cmd *cobra.Command) {
|
||||
mainLog.Error().Msg("Unsupported or incorrectly chosen router platform. Please open an issue and provide all relevant information: https://github.com/Control-D-Inc/ctrld/issues/new")
|
||||
}
|
||||
|
||||
func userHomeDir() (string, error) {
|
||||
switch router.Name() {
|
||||
case router.DDWrt, router.Merlin, router.Tomato:
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Dir(exe), nil
|
||||
}
|
||||
// viper will expand for us.
|
||||
if runtime.GOOS == "windows" {
|
||||
return os.UserHomeDir()
|
||||
}
|
||||
dir := "/etc/controld"
|
||||
if err := os.MkdirAll(dir, 0750); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dir, nil
|
||||
}
|
||||
|
||||
func tryReadingConfig(writeDefaultConfig bool) {
|
||||
configs := []struct {
|
||||
name string
|
||||
written bool
|
||||
}{
|
||||
// For compatibility, we check for config.toml first, but only read it if exists.
|
||||
{"config", false},
|
||||
{"ctrld", writeDefaultConfig},
|
||||
}
|
||||
|
||||
dir, err := userHomeDir()
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("failed to get config dir: %v", err)
|
||||
}
|
||||
for _, config := range configs {
|
||||
ctrld.SetConfigNameWithPath(v, config.name, dir)
|
||||
v.SetConfigFile(configPath)
|
||||
if readConfigFile(config.written) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
//go:build linux || freebsd
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
func initRouterCLI() {
|
||||
validArgs := append(router.SupportedPlatforms(), "auto")
|
||||
var b strings.Builder
|
||||
b.WriteString("Auto-setup Control D on a router.\n\nSupported platforms:\n\n")
|
||||
for _, arg := range validArgs {
|
||||
b.WriteString(" ₒ ")
|
||||
b.WriteString(arg)
|
||||
if arg == "auto" {
|
||||
b.WriteString(" - detect the platform you are running on")
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
routerCmd := &cobra.Command{
|
||||
Use: "setup",
|
||||
Short: b.String(),
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
initConsoleLogging()
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
if len(args) == 0 {
|
||||
_ = cmd.Help()
|
||||
return
|
||||
}
|
||||
if len(args) != 1 {
|
||||
_ = cmd.Help()
|
||||
return
|
||||
}
|
||||
platform := args[0]
|
||||
if platform == "auto" {
|
||||
platform = router.Name()
|
||||
}
|
||||
if !router.IsSupported(platform) {
|
||||
unsupportedPlatformHelp(cmd)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
mainLog.Fatal().Msgf("could not find executable path: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
cmdArgs := []string{"start"}
|
||||
cmdArgs = append(cmdArgs, osArgs(platform)...)
|
||||
cmdArgs = append(cmdArgs, "--router")
|
||||
command := exec.Command(exe, cmdArgs...)
|
||||
command.Stdout = os.Stdout
|
||||
command.Stderr = os.Stderr
|
||||
command.Stdin = os.Stdin
|
||||
if err := command.Run(); err != nil {
|
||||
mainLog.Fatal().Msg(err.Error())
|
||||
}
|
||||
},
|
||||
}
|
||||
// Keep these flags in sync with startCmd, except for "--router".
|
||||
routerCmd.Flags().StringVarP(&configPath, "config", "c", "", "Path to config file")
|
||||
routerCmd.Flags().StringVarP(&configBase64, "base64_config", "", "", "Base64 encoded config")
|
||||
routerCmd.Flags().StringVarP(&listenAddress, "listen", "", "", "Listener address and port, in format: address:port")
|
||||
routerCmd.Flags().StringVarP(&primaryUpstream, "primary_upstream", "", "", "Primary upstream endpoint")
|
||||
routerCmd.Flags().StringVarP(&secondaryUpstream, "secondary_upstream", "", "", "Secondary upstream endpoint")
|
||||
routerCmd.Flags().StringSliceVarP(&domains, "domains", "", nil, "List of domain to apply in a split DNS policy")
|
||||
routerCmd.Flags().StringVarP(&logPath, "log", "", "", "Path to log file")
|
||||
routerCmd.Flags().IntVarP(&cacheSize, "cache_size", "", 0, "Enable cache with size items")
|
||||
routerCmd.Flags().StringVarP(&cdUID, "cd", "", "", "Control D resolver uid")
|
||||
routerCmd.Flags().BoolVarP(&cdDev, "dev", "", false, "Use Control D dev resolver/domain")
|
||||
_ = routerCmd.Flags().MarkHidden("dev")
|
||||
routerCmd.Flags().StringVarP(&iface, "iface", "", "", `Update DNS setting for iface, "auto" means the default interface gateway`)
|
||||
|
||||
tmpl := routerCmd.UsageTemplate()
|
||||
tmpl = strings.Replace(tmpl, "{{.UseLine}}", "{{.UseLine}} [platform]", 1)
|
||||
routerCmd.SetUsageTemplate(tmpl)
|
||||
rootCmd.AddCommand(routerCmd)
|
||||
}
|
||||
|
||||
func osArgs(platform string) []string {
|
||||
args := os.Args[2:]
|
||||
n := 0
|
||||
for _, x := range args {
|
||||
if x != platform && x != "auto" {
|
||||
args[n] = x
|
||||
n++
|
||||
}
|
||||
}
|
||||
return args[:n]
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
//go:build !linux && !freebsd
|
||||
|
||||
package main
|
||||
|
||||
func initRouterCLI() {}
|
||||
@@ -1,23 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_writeConfigFile(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
// simulate --config CLI flag by setting configPath manually.
|
||||
configPath = filepath.Join(tmpdir, "ctrld.toml")
|
||||
_, err := os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err))
|
||||
|
||||
assert.NoError(t, writeConfigFile())
|
||||
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
package main
|
||||
|
||||
//lint:ignore U1000 use in os_linux.go
|
||||
type getDNS func(iface string) []string
|
||||
@@ -1,519 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
ctrldnet "github.com/Control-D-Inc/ctrld/internal/net"
|
||||
"github.com/Control-D-Inc/ctrld/internal/router"
|
||||
)
|
||||
|
||||
const (
|
||||
staleTTL = 60 * time.Second
|
||||
// EDNS0_OPTION_MAC is dnsmasq EDNS0 code for adding mac option.
|
||||
// https://thekelleys.org.uk/gitweb/?p=dnsmasq.git;a=blob;f=src/dns-protocol.h;h=76ac66a8c28317e9c121a74ab5fd0e20f6237dc8;hb=HEAD#l81
|
||||
// This is also dns.EDNS0LOCALSTART, but define our own constant here for clarification.
|
||||
EDNS0_OPTION_MAC = 0xFDE9
|
||||
)
|
||||
|
||||
var osUpstreamConfig = &ctrld.UpstreamConfig{
|
||||
Name: "OS resolver",
|
||||
Type: ctrld.ResolverTypeOS,
|
||||
Timeout: 2000,
|
||||
}
|
||||
|
||||
func (p *prog) serveDNS(listenerNum string) error {
|
||||
listenerConfig := p.cfg.Listener[listenerNum]
|
||||
// make sure ip is allocated
|
||||
if allocErr := p.allocateIP(listenerConfig.IP); allocErr != nil {
|
||||
mainLog.Error().Err(allocErr).Str("ip", listenerConfig.IP).Msg("serveUDP: failed to allocate listen ip")
|
||||
return allocErr
|
||||
}
|
||||
var failoverRcodes []int
|
||||
if listenerConfig.Policy != nil {
|
||||
failoverRcodes = listenerConfig.Policy.FailoverRcodeNumbers
|
||||
}
|
||||
handler := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
q := m.Question[0]
|
||||
domain := canonicalName(q.Name)
|
||||
reqId := requestID()
|
||||
remoteAddr := spoofRemoteAddr(w.RemoteAddr(), router.GetClientInfoByMac(macFromMsg(m)))
|
||||
fmtSrcToDest := fmtRemoteToLocal(listenerNum, remoteAddr.String(), w.LocalAddr().String())
|
||||
t := time.Now()
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, reqId)
|
||||
ctrld.Log(ctx, mainLog.Debug(), "%s received query: %s %s", fmtSrcToDest, dns.TypeToString[q.Qtype], domain)
|
||||
upstreams, matched := p.upstreamFor(ctx, listenerNum, listenerConfig, remoteAddr, domain)
|
||||
var answer *dns.Msg
|
||||
if !matched && listenerConfig.Restricted {
|
||||
answer = new(dns.Msg)
|
||||
answer.SetRcode(m, dns.RcodeRefused)
|
||||
|
||||
} else {
|
||||
answer = p.proxy(ctx, upstreams, failoverRcodes, m)
|
||||
rtt := time.Since(t)
|
||||
ctrld.Log(ctx, mainLog.Debug(), "received response of %d bytes in %s", answer.Len(), rtt)
|
||||
}
|
||||
if err := w.WriteMsg(answer); err != nil {
|
||||
ctrld.Log(ctx, mainLog.Error().Err(err), "serveUDP: failed to send DNS response to client")
|
||||
}
|
||||
})
|
||||
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
for _, proto := range []string{"udp", "tcp"} {
|
||||
proto := proto
|
||||
if needLocalIPv6Listener() {
|
||||
g.Go(func() error {
|
||||
s, errCh := runDNSServer(net.JoinHostPort("::1", strconv.Itoa(listenerConfig.Port)), proto, handler)
|
||||
defer s.Shutdown()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case err := <-errCh:
|
||||
// Local ipv6 listener should not terminate ctrld.
|
||||
// It's a workaround for a quirk on Windows.
|
||||
mainLog.Warn().Err(err).Msg("local ipv6 listener failed")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
g.Go(func() error {
|
||||
s, errCh := runDNSServer(dnsListenAddress(listenerNum, listenerConfig), proto, handler)
|
||||
defer s.Shutdown()
|
||||
if listenerConfig.Port == 0 {
|
||||
switch s.Net {
|
||||
case "udp":
|
||||
mainLog.Info().Msgf("Random port chosen for udp listener.%s: %s", listenerNum, s.PacketConn.LocalAddr())
|
||||
case "tcp":
|
||||
mainLog.Info().Msgf("Random port chosen for tcp listener.%s: %s", listenerNum, s.Listener.Addr())
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
}
|
||||
})
|
||||
}
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
// upstreamFor returns the list of upstreams for resolving the given domain,
|
||||
// matching by policies defined in the listener config. The second return value
|
||||
// reports whether the domain matches the policy.
|
||||
//
|
||||
// Though domain policy has higher priority than network policy, it is still
|
||||
// processed later, because policy logging want to know whether a network rule
|
||||
// is disregarded in favor of the domain level rule.
|
||||
func (p *prog) upstreamFor(ctx context.Context, defaultUpstreamNum string, lc *ctrld.ListenerConfig, addr net.Addr, domain string) ([]string, bool) {
|
||||
upstreams := []string{"upstream." + defaultUpstreamNum}
|
||||
matchedPolicy := "no policy"
|
||||
matchedNetwork := "no network"
|
||||
matchedRule := "no rule"
|
||||
matched := false
|
||||
|
||||
defer func() {
|
||||
if !matched && lc.Restricted {
|
||||
ctrld.Log(ctx, mainLog.Info(), "query refused, %s does not match any network policy", addr.String())
|
||||
return
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Info(), "%s, %s, %s -> %v", matchedPolicy, matchedNetwork, matchedRule, upstreams)
|
||||
}()
|
||||
|
||||
if lc.Policy == nil {
|
||||
return upstreams, false
|
||||
}
|
||||
|
||||
do := func(policyUpstreams []string) {
|
||||
upstreams = append([]string(nil), policyUpstreams...)
|
||||
}
|
||||
|
||||
var networkTargets []string
|
||||
var sourceIP net.IP
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
sourceIP = addr.IP
|
||||
case *net.TCPAddr:
|
||||
sourceIP = addr.IP
|
||||
}
|
||||
|
||||
networkRules:
|
||||
for _, rule := range lc.Policy.Networks {
|
||||
for source, targets := range rule {
|
||||
networkNum := strings.TrimPrefix(source, "network.")
|
||||
nc := p.cfg.Network[networkNum]
|
||||
if nc == nil {
|
||||
continue
|
||||
}
|
||||
for _, ipNet := range nc.IPNets {
|
||||
if ipNet.Contains(sourceIP) {
|
||||
matchedPolicy = lc.Policy.Name
|
||||
matchedNetwork = source
|
||||
networkTargets = targets
|
||||
matched = true
|
||||
break networkRules
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range lc.Policy.Rules {
|
||||
// There's only one entry per rule, config validation ensures this.
|
||||
for source, targets := range rule {
|
||||
if source == domain || wildcardMatches(source, domain) {
|
||||
matchedPolicy = lc.Policy.Name
|
||||
if len(networkTargets) > 0 {
|
||||
matchedNetwork += " (unenforced)"
|
||||
}
|
||||
matchedRule = source
|
||||
do(targets)
|
||||
matched = true
|
||||
return upstreams, matched
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if matched {
|
||||
do(networkTargets)
|
||||
}
|
||||
|
||||
return upstreams, matched
|
||||
}
|
||||
|
||||
func (p *prog) proxy(ctx context.Context, upstreams []string, failoverRcodes []int, msg *dns.Msg) *dns.Msg {
|
||||
var staleAnswer *dns.Msg
|
||||
serveStaleCache := p.cache != nil && p.cfg.Service.CacheServeStale
|
||||
upstreamConfigs := p.upstreamConfigsFromUpstreamNumbers(upstreams)
|
||||
if len(upstreamConfigs) == 0 {
|
||||
upstreamConfigs = []*ctrld.UpstreamConfig{osUpstreamConfig}
|
||||
upstreams = []string{"upstream.os"}
|
||||
}
|
||||
// Inverse query should not be cached: https://www.rfc-editor.org/rfc/rfc1035#section-7.4
|
||||
if p.cache != nil && msg.Question[0].Qtype != dns.TypePTR {
|
||||
for _, upstream := range upstreams {
|
||||
cachedValue := p.cache.Get(dnscache.NewKey(msg, upstream))
|
||||
if cachedValue == nil {
|
||||
continue
|
||||
}
|
||||
answer := cachedValue.Msg.Copy()
|
||||
answer.SetRcode(msg, answer.Rcode)
|
||||
now := time.Now()
|
||||
if cachedValue.Expire.After(now) {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "hit cached response")
|
||||
setCachedAnswerTTL(answer, now, cachedValue.Expire)
|
||||
return answer
|
||||
}
|
||||
staleAnswer = answer
|
||||
}
|
||||
}
|
||||
resolve1 := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) (*dns.Msg, error) {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "sending query to %s: %s", upstreams[n], upstreamConfig.Name)
|
||||
dnsResolver, err := ctrld.NewResolver(upstreamConfig)
|
||||
if err != nil {
|
||||
ctrld.Log(ctx, mainLog.Error().Err(err), "failed to create resolver")
|
||||
return nil, err
|
||||
}
|
||||
resolveCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
if upstreamConfig.Timeout > 0 {
|
||||
timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(upstreamConfig.Timeout))
|
||||
defer cancel()
|
||||
resolveCtx = timeoutCtx
|
||||
}
|
||||
return dnsResolver.Resolve(resolveCtx, msg)
|
||||
}
|
||||
resolve := func(n int, upstreamConfig *ctrld.UpstreamConfig, msg *dns.Msg) *dns.Msg {
|
||||
if upstreamConfig.UpstreamSendClientInfo() {
|
||||
ci := router.GetClientInfoByMac(macFromMsg(msg))
|
||||
if ci != nil {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "including client info with the request")
|
||||
ctx = context.WithValue(ctx, ctrld.ClientInfoCtxKey{}, ci)
|
||||
}
|
||||
}
|
||||
answer, err := resolve1(n, upstreamConfig, msg)
|
||||
// Only do re-bootstrapping if bootstrap ip is not explicitly set by user.
|
||||
if err != nil && upstreamConfig.BootstrapIP == "" {
|
||||
ctrld.Log(ctx, mainLog.Debug().Err(err), "could not resolve query on first attempt, retrying...")
|
||||
// If any error occurred, re-bootstrap transport/ip, retry the request.
|
||||
upstreamConfig.ReBootstrap()
|
||||
answer, err = resolve1(n, upstreamConfig, msg)
|
||||
if err == nil {
|
||||
return answer
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Error().Err(err), "failed to resolve query")
|
||||
return nil
|
||||
}
|
||||
return answer
|
||||
}
|
||||
for n, upstreamConfig := range upstreamConfigs {
|
||||
if upstreamConfig == nil {
|
||||
continue
|
||||
}
|
||||
answer := resolve(n, upstreamConfig, msg)
|
||||
if answer == nil {
|
||||
if serveStaleCache && staleAnswer != nil {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "serving stale cached response")
|
||||
now := time.Now()
|
||||
setCachedAnswerTTL(staleAnswer, now, now.Add(staleTTL))
|
||||
return staleAnswer
|
||||
}
|
||||
continue
|
||||
}
|
||||
if answer.Rcode != dns.RcodeSuccess && len(upstreamConfigs) > 1 && containRcode(failoverRcodes, answer.Rcode) {
|
||||
ctrld.Log(ctx, mainLog.Debug(), "failover rcode matched, process to next upstream")
|
||||
continue
|
||||
}
|
||||
|
||||
// set compression, as it is not set by default when unpacking
|
||||
answer.Compress = true
|
||||
|
||||
if p.cache != nil {
|
||||
ttl := ttlFromMsg(answer)
|
||||
now := time.Now()
|
||||
expired := now.Add(time.Duration(ttl) * time.Second)
|
||||
if cachedTTL := p.cfg.Service.CacheTTLOverride; cachedTTL > 0 {
|
||||
expired = now.Add(time.Duration(cachedTTL) * time.Second)
|
||||
}
|
||||
setCachedAnswerTTL(answer, now, expired)
|
||||
p.cache.Add(dnscache.NewKey(msg, upstreams[n]), dnscache.NewValue(answer, expired))
|
||||
ctrld.Log(ctx, mainLog.Debug(), "add cached response")
|
||||
}
|
||||
return answer
|
||||
}
|
||||
ctrld.Log(ctx, mainLog.Error(), "all upstreams failed")
|
||||
answer := new(dns.Msg)
|
||||
answer.SetRcode(msg, dns.RcodeServerFailure)
|
||||
return answer
|
||||
}
|
||||
|
||||
func (p *prog) upstreamConfigsFromUpstreamNumbers(upstreams []string) []*ctrld.UpstreamConfig {
|
||||
upstreamConfigs := make([]*ctrld.UpstreamConfig, 0, len(upstreams))
|
||||
for _, upstream := range upstreams {
|
||||
upstreamNum := strings.TrimPrefix(upstream, "upstream.")
|
||||
upstreamConfigs = append(upstreamConfigs, p.cfg.Upstream[upstreamNum])
|
||||
}
|
||||
return upstreamConfigs
|
||||
}
|
||||
|
||||
// canonicalName returns canonical name from FQDN with "." trimmed.
|
||||
func canonicalName(fqdn string) string {
|
||||
q := strings.TrimSpace(fqdn)
|
||||
q = strings.TrimSuffix(q, ".")
|
||||
// https://datatracker.ietf.org/doc/html/rfc4343
|
||||
q = strings.ToLower(q)
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
func wildcardMatches(wildcard, domain string) bool {
|
||||
// Wildcard match.
|
||||
wildCardParts := strings.Split(wildcard, "*")
|
||||
if len(wildCardParts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(wildCardParts[0]) > 0 && len(wildCardParts[1]) > 0:
|
||||
// Domain must match both prefix and suffix.
|
||||
return strings.HasPrefix(domain, wildCardParts[0]) && strings.HasSuffix(domain, wildCardParts[1])
|
||||
|
||||
case len(wildCardParts[1]) > 0:
|
||||
// Only suffix must match.
|
||||
return strings.HasSuffix(domain, wildCardParts[1])
|
||||
|
||||
case len(wildCardParts[0]) > 0:
|
||||
// Only prefix must match.
|
||||
return strings.HasPrefix(domain, wildCardParts[0])
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func fmtRemoteToLocal(listenerNum, remote, local string) string {
|
||||
return fmt.Sprintf("%s -> listener.%s: %s:", remote, listenerNum, local)
|
||||
}
|
||||
|
||||
func requestID() string {
|
||||
b := make([]byte, 3) // 6 chars
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func containRcode(rcodes []int, rcode int) bool {
|
||||
for i := range rcodes {
|
||||
if rcodes[i] == rcode {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func setCachedAnswerTTL(answer *dns.Msg, now, expiredTime time.Time) {
|
||||
ttlSecs := expiredTime.Sub(now).Seconds()
|
||||
if ttlSecs < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
ttl := uint32(ttlSecs)
|
||||
for _, rr := range answer.Answer {
|
||||
rr.Header().Ttl = ttl
|
||||
}
|
||||
for _, rr := range answer.Ns {
|
||||
rr.Header().Ttl = ttl
|
||||
}
|
||||
for _, rr := range answer.Extra {
|
||||
if rr.Header().Rrtype != dns.TypeOPT {
|
||||
rr.Header().Ttl = ttl
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ttlFromMsg(msg *dns.Msg) uint32 {
|
||||
for _, rr := range msg.Answer {
|
||||
return rr.Header().Ttl
|
||||
}
|
||||
for _, rr := range msg.Ns {
|
||||
return rr.Header().Ttl
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func needLocalIPv6Listener() bool {
|
||||
// On Windows, there's no easy way for disabling/removing IPv6 DNS resolver, so we check whether we can
|
||||
// listen on ::1, then spawn a listener for receiving DNS requests.
|
||||
return ctrldnet.SupportsIPv6ListenLocal() && runtime.GOOS == "windows"
|
||||
}
|
||||
|
||||
func dnsListenAddress(lcNum string, lc *ctrld.ListenerConfig) string {
|
||||
if addr := router.ListenAddress(); setupRouter && addr != "" && lcNum == "0" {
|
||||
return addr
|
||||
}
|
||||
return net.JoinHostPort(lc.IP, strconv.Itoa(lc.Port))
|
||||
}
|
||||
|
||||
func macFromMsg(msg *dns.Msg) string {
|
||||
if opt := msg.IsEdns0(); opt != nil {
|
||||
for _, s := range opt.Option {
|
||||
switch e := s.(type) {
|
||||
case *dns.EDNS0_LOCAL:
|
||||
if e.Code == EDNS0_OPTION_MAC {
|
||||
return net.HardwareAddr(e.Data).String()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func spoofRemoteAddr(addr net.Addr, ci *ctrld.ClientInfo) net.Addr {
|
||||
if ci != nil && ci.IP != "" {
|
||||
switch addr := addr.(type) {
|
||||
case *net.UDPAddr:
|
||||
udpAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(ci.IP),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
return udpAddr
|
||||
case *net.TCPAddr:
|
||||
udpAddr := &net.TCPAddr{
|
||||
IP: net.ParseIP(ci.IP),
|
||||
Port: addr.Port,
|
||||
Zone: addr.Zone,
|
||||
}
|
||||
return udpAddr
|
||||
}
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
// runDNSServer starts a DNS server for given address and network,
|
||||
// with the given handler. It ensures the server has started listening.
|
||||
// Any error will be reported to the caller via returned channel.
|
||||
//
|
||||
// It's the caller responsibility to call Shutdown to close the server.
|
||||
func runDNSServer(addr, network string, handler dns.Handler) (*dns.Server, <-chan error) {
|
||||
s := &dns.Server{
|
||||
Addr: addr,
|
||||
Net: network,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
s.NotifyStartedFunc = waitLock.Unlock
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
if err := s.ListenAndServe(); err != nil {
|
||||
waitLock.Unlock()
|
||||
mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
waitLock.Lock()
|
||||
return s, errCh
|
||||
}
|
||||
|
||||
// runDNSServerForNTPD starts a DNS server listening on router.ListenAddress(). It must only be called when ctrld
|
||||
// running on router, before router.PreRun() to serve DNS request for NTP synchronization. The caller must call
|
||||
// s.Shutdown() explicitly when NTP is synced successfully.
|
||||
func runDNSServerForNTPD(addr string) (*dns.Server, <-chan error) {
|
||||
if addr == "" {
|
||||
return &dns.Server{}, nil
|
||||
}
|
||||
dnsResolver := ctrld.NewBootstrapResolver()
|
||||
s := &dns.Server{
|
||||
Addr: addr,
|
||||
Net: "udp",
|
||||
Handler: dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||
mainLog.Debug().Msg("Serving query for ntpd")
|
||||
resolveCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
if osUpstreamConfig.Timeout > 0 {
|
||||
timeoutCtx, cancel := context.WithTimeout(resolveCtx, time.Millisecond*time.Duration(osUpstreamConfig.Timeout))
|
||||
defer cancel()
|
||||
resolveCtx = timeoutCtx
|
||||
}
|
||||
answer, err := dnsResolver.Resolve(resolveCtx, m)
|
||||
if err != nil {
|
||||
mainLog.Error().Err(err).Msgf("could not resolve: %v", m)
|
||||
return
|
||||
}
|
||||
if err := w.WriteMsg(answer); err != nil {
|
||||
mainLog.Error().Err(err).Msg("runDNSServerForNTPD: failed to send DNS response")
|
||||
}
|
||||
}),
|
||||
}
|
||||
|
||||
waitLock := sync.Mutex{}
|
||||
waitLock.Lock()
|
||||
s.NotifyStartedFunc = waitLock.Unlock
|
||||
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
if err := s.ListenAndServe(); err != nil {
|
||||
waitLock.Unlock()
|
||||
mainLog.Error().Err(err).Msgf("could not listen and serve on: %s", s.Addr)
|
||||
errCh <- err
|
||||
}
|
||||
}()
|
||||
waitLock.Lock()
|
||||
return s, errCh
|
||||
}
|
||||
@@ -1,218 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Control-D-Inc/ctrld"
|
||||
"github.com/Control-D-Inc/ctrld/internal/dnscache"
|
||||
"github.com/Control-D-Inc/ctrld/testhelper"
|
||||
)
|
||||
|
||||
func Test_wildcardMatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wildcard string
|
||||
domain string
|
||||
match bool
|
||||
}{
|
||||
{"prefix parent should not match", "*.windscribe.com", "windscribe.com", false},
|
||||
{"prefix", "*.windscribe.com", "anything.windscribe.com", true},
|
||||
{"prefix not match other domain", "*.windscribe.com", "example.com", false},
|
||||
{"prefix not match domain in name", "*.windscribe.com", "wwindscribe.com", false},
|
||||
{"suffix", "suffix.*", "suffix.windscribe.com", true},
|
||||
{"suffix not match other", "suffix.*", "suffix1.windscribe.com", false},
|
||||
{"both", "suffix.*.windscribe.com", "suffix.anything.windscribe.com", true},
|
||||
{"both not match", "suffix.*.windscribe.com", "suffix1.suffix.windscribe.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := wildcardMatches(tc.wildcard, tc.domain); got != tc.match {
|
||||
t.Errorf("unexpected result, wildcard: %s, domain: %s, want: %v, got: %v", tc.wildcard, tc.domain, tc.match, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_canonicalName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
canonical string
|
||||
}{
|
||||
{"fqdn to canonical", "windscribe.com.", "windscribe.com"},
|
||||
{"already canonical", "windscribe.com", "windscribe.com"},
|
||||
{"case insensitive", "Windscribe.Com.", "windscribe.com"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := canonicalName(tc.domain); got != tc.canonical {
|
||||
t.Errorf("unexpected result, want: %s, got: %s", tc.canonical, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prog_upstreamFor(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
defaultUpstreamNum string
|
||||
lc *ctrld.ListenerConfig
|
||||
domain string
|
||||
upstreams []string
|
||||
matched bool
|
||||
testLogMsg string
|
||||
}{
|
||||
{"Policy map matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.1", "upstream.0"}, true, ""},
|
||||
{"Policy split matches", "192.168.0.1:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, ""},
|
||||
{"Policy map for other network matches", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.xyz", []string{"upstream.0"}, true, ""},
|
||||
{"No policy map for listener", "192.168.1.2:0", "1", prog.cfg.Listener["1"], "abc.ru", []string{"upstream.1"}, false, ""},
|
||||
{"unenforced loging", "192.168.1.2:0", "0", prog.cfg.Listener["0"], "abc.ru", []string{"upstream.1"}, true, "My Policy, network.1 (unenforced), *.ru -> [upstream.1]"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, network := range []string{"udp", "tcp"} {
|
||||
var (
|
||||
addr net.Addr
|
||||
err error
|
||||
)
|
||||
switch network {
|
||||
case "udp":
|
||||
addr, err = net.ResolveUDPAddr(network, tc.ip)
|
||||
case "tcp":
|
||||
addr, err = net.ResolveTCPAddr(network, tc.ip)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, addr)
|
||||
ctx := context.WithValue(context.Background(), ctrld.ReqIdCtxKey{}, requestID())
|
||||
upstreams, matched := prog.upstreamFor(ctx, tc.defaultUpstreamNum, tc.lc, addr, tc.domain)
|
||||
assert.Equal(t, tc.matched, matched)
|
||||
assert.Equal(t, tc.upstreams, upstreams)
|
||||
if tc.testLogMsg != "" {
|
||||
assert.Contains(t, logOutput.String(), tc.testLogMsg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache(t *testing.T) {
|
||||
cfg := testhelper.SampleConfig(t)
|
||||
prog := &prog{cfg: cfg}
|
||||
for _, nc := range prog.cfg.Network {
|
||||
for _, cidr := range nc.Cidrs {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
nc.IPNets = append(nc.IPNets, ipNet)
|
||||
}
|
||||
}
|
||||
cacher, err := dnscache.NewLRUCache(4096)
|
||||
require.NoError(t, err)
|
||||
prog.cache = cacher
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com", dns.TypeA)
|
||||
msg.MsgHdr.RecursionDesired = true
|
||||
answer1 := new(dns.Msg)
|
||||
answer1.SetRcode(msg, dns.RcodeSuccess)
|
||||
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.1"), dnscache.NewValue(answer1, time.Now().Add(time.Minute)))
|
||||
answer2 := new(dns.Msg)
|
||||
answer2.SetRcode(msg, dns.RcodeRefused)
|
||||
prog.cache.Add(dnscache.NewKey(msg, "upstream.0"), dnscache.NewValue(answer2, time.Now().Add(time.Minute)))
|
||||
|
||||
got1 := prog.proxy(context.Background(), []string{"upstream.1"}, nil, msg)
|
||||
got2 := prog.proxy(context.Background(), []string{"upstream.0"}, nil, msg)
|
||||
assert.NotSame(t, got1, got2)
|
||||
assert.Equal(t, answer1.Rcode, got1.Rcode)
|
||||
assert.Equal(t, answer2.Rcode, got2.Rcode)
|
||||
}
|
||||
|
||||
func Test_macFromMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mac string
|
||||
wantMac bool
|
||||
}{
|
||||
{"has mac", "4c:20:b8:ab:87:1b", true},
|
||||
{"no mac", "4c:20:b8:ab:87:1b", false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
hw, err := net.ParseMAC(tc.mac)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion("example.com.", dns.TypeA)
|
||||
o := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
|
||||
if tc.wantMac {
|
||||
ec1 := &dns.EDNS0_LOCAL{Code: EDNS0_OPTION_MAC, Data: hw}
|
||||
o.Option = append(o.Option, ec1)
|
||||
}
|
||||
m.Extra = append(m.Extra, o)
|
||||
got := macFromMsg(m)
|
||||
if tc.wantMac && got != tc.mac {
|
||||
t.Errorf("mismatch, want: %q, got: %q", tc.mac, got)
|
||||
}
|
||||
if !tc.wantMac && got != "" {
|
||||
t.Errorf("unexpected mac: %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_remoteAddrFromMsg(t *testing.T) {
|
||||
loopbackIP := net.ParseIP("127.0.0.1")
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
ci *ctrld.ClientInfo
|
||||
want string
|
||||
}{
|
||||
{"tcp", &net.TCPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.10"}, "192.168.1.10:12345"},
|
||||
{"udp", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{IP: "192.168.1.11"}, "192.168.1.11:12345"},
|
||||
{"nil client info", &net.UDPAddr{IP: loopbackIP, Port: 12345}, nil, "127.0.0.1:12345"},
|
||||
{"empty ip", &net.UDPAddr{IP: loopbackIP, Port: 12345}, &ctrld.ClientInfo{}, "127.0.0.1:12345"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
addr := spoofRemoteAddr(tc.addr, tc.ci)
|
||||
if addr.String() != tc.want {
|
||||
t.Errorf("unexpected result, want: %q, got: %q", tc.want, addr.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user