Files
ctrld/cmd/cli/control_server.go
Alex Paguis 7833132917 Don't automatically restore saved DNS settings when switching networks
smol tweaks to nameserver test queries

fix restoreDNS errors

add some debugging information

fix wront type in log msg

set send logs command timeout to 5 mins

when the runningIface is no longer up, attempt to find a new interface

prefer default route, ignore non physical interfaces

prefer default route, ignore non physical interfaces

add max context timeout on performLeakingQuery with more debug logs
2025-01-20 14:59:31 +07:00

291 lines
7.8 KiB
Go

package cli
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"reflect"
"sort"
"time"
"github.com/kardianos/service"
dto "github.com/prometheus/client_model/go"
"github.com/Control-D-Inc/ctrld"
"github.com/Control-D-Inc/ctrld/internal/controld"
)
const (
contentTypeJson = "application/json"
listClientsPath = "/clients"
startedPath = "/started"
reloadPath = "/reload"
deactivationPath = "/deactivation"
cdPath = "/cd"
ifacePath = "/iface"
viewLogsPath = "/log/view"
sendLogsPath = "/log/send"
)
type ifaceResponse struct {
Name string `json:"name"`
All bool `json:"all"`
}
type controlServer struct {
server *http.Server
mux *http.ServeMux
addr string
}
func newControlServer(addr string) (*controlServer, error) {
mux := http.NewServeMux()
s := &controlServer{
server: &http.Server{Handler: mux},
mux: mux,
}
s.addr = addr
return s, nil
}
func (s *controlServer) start() error {
_ = os.Remove(s.addr)
unixListener, err := net.Listen("unix", s.addr)
if l, ok := unixListener.(*net.UnixListener); ok {
l.SetUnlinkOnClose(true)
}
if err != nil {
return err
}
go s.server.Serve(unixListener)
return nil
}
func (s *controlServer) stop() error {
_ = os.Remove(s.addr)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
defer cancel()
return s.server.Shutdown(ctx)
}
func (s *controlServer) register(pattern string, handler http.Handler) {
s.mux.Handle(pattern, jsonResponse(handler))
}
func (p *prog) registerControlServerHandler() {
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
clients := p.ciTable.ListClients()
sort.Slice(clients, func(i, j int) bool {
return clients[i].IP.Less(clients[j].IP)
})
if p.metricsQueryStats.Load() {
for _, client := range clients {
client.IncludeQueryCount = true
dm := &dto.Metric{}
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
client.IP.String(),
client.Mac,
client.Hostname,
)
if err != nil {
mainLog.Load().Debug().Err(err).Msgf("could not get metrics for client: %v", client)
continue
}
if err := m.Write(dm); err == nil {
client.QueryCount = int64(dm.Counter.GetValue())
}
}
}
if err := json.NewEncoder(w).Encode(&clients); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}))
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
select {
case <-p.onStartedDone:
w.WriteHeader(http.StatusOK)
case <-time.After(10 * time.Second):
w.WriteHeader(http.StatusRequestTimeout)
}
}))
p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
listeners := make(map[string]*ctrld.ListenerConfig)
p.mu.Lock()
for k, v := range p.cfg.Listener {
listeners[k] = &ctrld.ListenerConfig{
IP: v.IP,
Port: v.Port,
}
}
oldSvc := p.cfg.Service
p.mu.Unlock()
if err := p.sendReloadSignal(); err != nil {
mainLog.Load().Err(err).Msg("could not send reload signal")
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
select {
case <-p.reloadDoneCh:
case <-time.After(5 * time.Second):
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
return
}
p.mu.Lock()
defer p.mu.Unlock()
// Checking for cases that we could not do a reload.
// 1. Listener config ip or port changes.
for k, v := range p.cfg.Listener {
l := listeners[k]
if l == nil || l.IP != v.IP || l.Port != v.Port {
w.WriteHeader(http.StatusCreated)
return
}
}
// 2. Service config changes.
if !reflect.DeepEqual(oldSvc, p.cfg.Service) {
w.WriteHeader(http.StatusCreated)
return
}
// Otherwise, reload is done.
w.WriteHeader(http.StatusOK)
}))
p.cs.register(deactivationPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
// Non-cd mode always allowing deactivation.
if cdUID == "" {
w.WriteHeader(http.StatusOK)
return
}
// Re-fetch pin code from API.
if rc, err := controld.FetchResolverConfig(cdUID, rootCmd.Version, cdDev); rc != nil {
if rc.DeactivationPin != nil {
cdDeactivationPin.Store(*rc.DeactivationPin)
} else {
cdDeactivationPin.Store(defaultDeactivationPin)
}
} else {
mainLog.Load().Warn().Err(err).Msg("could not re-fetch deactivation pin code")
}
// If pin code not set, allowing deactivation.
if deactivationPinNotSet() {
w.WriteHeader(http.StatusOK)
return
}
var req deactivationRequest
if err := json.NewDecoder(request.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusPreconditionFailed)
mainLog.Load().Err(err).Msg("invalid deactivation request")
return
}
code := http.StatusForbidden
switch req.Pin {
case cdDeactivationPin.Load():
code = http.StatusOK
case defaultDeactivationPin:
// If the pin code was set, but users do not provide --pin, return proper code to client.
code = http.StatusBadRequest
}
w.WriteHeader(code)
}))
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
if cdUID != "" {
w.WriteHeader(http.StatusOK)
w.Write([]byte(cdUID))
return
}
w.WriteHeader(http.StatusBadRequest)
}))
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
res := &ifaceResponse{Name: iface}
// p.setDNS is only called when running as a service
if !service.Interactive() {
<-p.csSetDnsDone
if p.csSetDnsOk {
res.Name = p.runningIface
res.All = p.requiredMultiNICsConfig
}
}
if err := json.NewEncoder(w).Encode(res); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
return
}
}))
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
lr, err := p.logReader()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer lr.r.Close()
if lr.size == 0 {
w.WriteHeader(http.StatusMovedPermanently)
return
}
data, err := io.ReadAll(lr.r)
if err != nil {
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
return
}
}))
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
if time.Since(p.internalLogSent) < logSentInterval {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
r, err := p.logReader()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if r.size == 0 {
w.WriteHeader(http.StatusMovedPermanently)
return
}
req := &controld.LogsRequest{
UID: cdUID,
Data: r.r,
}
mainLog.Load().Debug().Msg("sending log file to ControlD server")
resp := logSentResponse{Size: r.size}
if err := controld.SendLogs(req, cdDev); err != nil {
mainLog.Load().Error().Msgf("could not send log file to ControlD server: %v", err)
resp.Error = err.Error()
w.WriteHeader(http.StatusInternalServerError)
} else {
mainLog.Load().Debug().Msg("sending log file successfully")
w.WriteHeader(http.StatusOK)
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
p.internalLogSent = time.Now()
}))
}
func jsonResponse(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
next.ServeHTTP(w, r)
})
}