mirror of
https://github.com/Control-D-Inc/ctrld.git
synced 2026-02-03 22:18:39 +00:00
- Add appVersion variable to store curVersion() result during init - Change initCLI() to return *cobra.Command - Move rootCmd creation inside initCLI() as local variable - Replace all rootCmd.Version usage with appVersion variable - Update Main() function to capture returned rootCmd from initCLI() - Remove sync.Once guard from tests and use initCLI() directly - Remove sync import from test file as it's no longer needed This refactoring improves encapsulation by eliminating global state, reduces version computation overhead, and simplifies test setup by removing the need for sync.Once guards. All tests pass and the application builds successfully.
347 lines
9.3 KiB
Go
347 lines
9.3 KiB
Go
package cli
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"reflect"
|
|
"sort"
|
|
"time"
|
|
|
|
"github.com/kardianos/service"
|
|
dto "github.com/prometheus/client_model/go"
|
|
|
|
"github.com/Control-D-Inc/ctrld"
|
|
"github.com/Control-D-Inc/ctrld/internal/controld"
|
|
)
|
|
|
|
const (
|
|
contentTypeJson = "application/json"
|
|
listClientsPath = "/clients"
|
|
startedPath = "/started"
|
|
reloadPath = "/reload"
|
|
deactivationPath = "/deactivation"
|
|
cdPath = "/cd"
|
|
ifacePath = "/iface"
|
|
viewLogsPath = "/log/view"
|
|
sendLogsPath = "/log/send"
|
|
)
|
|
|
|
type ifaceResponse struct {
|
|
Name string `json:"name"`
|
|
All bool `json:"all"`
|
|
OK bool `json:"ok"`
|
|
}
|
|
|
|
type controlServer struct {
|
|
server *http.Server
|
|
mux *http.ServeMux
|
|
addr string
|
|
}
|
|
|
|
func newControlServer(addr string) (*controlServer, error) {
|
|
mux := http.NewServeMux()
|
|
s := &controlServer{
|
|
server: &http.Server{Handler: mux},
|
|
mux: mux,
|
|
}
|
|
s.addr = addr
|
|
return s, nil
|
|
}
|
|
|
|
func (s *controlServer) start() error {
|
|
_ = os.Remove(s.addr)
|
|
unixListener, err := net.Listen("unix", s.addr)
|
|
if l, ok := unixListener.(*net.UnixListener); ok {
|
|
l.SetUnlinkOnClose(true)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
go s.server.Serve(unixListener)
|
|
return nil
|
|
}
|
|
|
|
func (s *controlServer) stop() error {
|
|
_ = os.Remove(s.addr)
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
|
|
defer cancel()
|
|
return s.server.Shutdown(ctx)
|
|
}
|
|
|
|
func (s *controlServer) register(pattern string, handler http.Handler) {
|
|
s.mux.Handle(pattern, jsonResponse(handler))
|
|
}
|
|
|
|
func (p *prog) registerControlServerHandler() {
|
|
p.cs.register(listClientsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
|
p.Debug().Msg("handling list clients request")
|
|
|
|
clients := p.ciTable.ListClients()
|
|
p.Debug().Int("client_count", len(clients)).Msg("retrieved clients list")
|
|
|
|
sort.Slice(clients, func(i, j int) bool {
|
|
return clients[i].IP.Less(clients[j].IP)
|
|
})
|
|
p.Debug().Msg("sorted clients by IP address")
|
|
|
|
if p.metricsQueryStats.Load() {
|
|
p.Debug().Msg("metrics query stats enabled, collecting query counts")
|
|
|
|
for idx, client := range clients {
|
|
p.Debug().
|
|
Int("index", idx).
|
|
Str("ip", client.IP.String()).
|
|
Str("mac", client.Mac).
|
|
Str("hostname", client.Hostname).
|
|
Msg("processing client metrics")
|
|
|
|
client.IncludeQueryCount = true
|
|
dm := &dto.Metric{}
|
|
|
|
if statsClientQueriesCount.MetricVec == nil {
|
|
p.Debug().
|
|
Str("client_ip", client.IP.String()).
|
|
Msg("skipping metrics collection: MetricVec is nil")
|
|
continue
|
|
}
|
|
|
|
m, err := statsClientQueriesCount.MetricVec.GetMetricWithLabelValues(
|
|
client.IP.String(),
|
|
client.Mac,
|
|
client.Hostname,
|
|
)
|
|
if err != nil {
|
|
p.Debug().
|
|
Err(err).
|
|
Str("client_ip", client.IP.String()).
|
|
Str("mac", client.Mac).
|
|
Str("hostname", client.Hostname).
|
|
Msg("failed to get metrics for client")
|
|
continue
|
|
}
|
|
|
|
if err := m.Write(dm); err == nil && dm.Counter != nil {
|
|
client.QueryCount = int64(dm.Counter.GetValue())
|
|
p.Debug().
|
|
Str("client_ip", client.IP.String()).
|
|
Int64("query_count", client.QueryCount).
|
|
Msg("successfully collected query count")
|
|
} else if err != nil {
|
|
p.Debug().
|
|
Err(err).
|
|
Str("client_ip", client.IP.String()).
|
|
Msg("failed to write metric")
|
|
}
|
|
}
|
|
} else {
|
|
p.Debug().Msg("metrics query stats disabled, skipping query counts")
|
|
}
|
|
|
|
if err := json.NewEncoder(w).Encode(&clients); err != nil {
|
|
p.Error().
|
|
Err(err).
|
|
Int("client_count", len(clients)).
|
|
Msg("failed to encode clients response")
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
p.Debug().
|
|
Int("client_count", len(clients)).
|
|
Msg("successfully sent clients list response")
|
|
}))
|
|
p.cs.register(startedPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
|
select {
|
|
case <-p.onStartedDone:
|
|
w.WriteHeader(http.StatusOK)
|
|
case <-time.After(10 * time.Second):
|
|
w.WriteHeader(http.StatusRequestTimeout)
|
|
}
|
|
}))
|
|
p.cs.register(reloadPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
|
listeners := make(map[string]*ctrld.ListenerConfig)
|
|
p.mu.Lock()
|
|
for k, v := range p.cfg.Listener {
|
|
listeners[k] = &ctrld.ListenerConfig{
|
|
IP: v.IP,
|
|
Port: v.Port,
|
|
}
|
|
}
|
|
oldSvc := p.cfg.Service
|
|
p.mu.Unlock()
|
|
if err := p.sendReloadSignal(); err != nil {
|
|
p.Error().Err(err).Msg("could not send reload signal")
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
select {
|
|
case <-p.reloadDoneCh:
|
|
case <-time.After(5 * time.Second):
|
|
http.Error(w, "timeout waiting for ctrld reload", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
// Checking for cases that we could not do a reload.
|
|
|
|
// 1. Listener config ip or port changes.
|
|
for k, v := range p.cfg.Listener {
|
|
l := listeners[k]
|
|
if l == nil || l.IP != v.IP || l.Port != v.Port {
|
|
w.WriteHeader(http.StatusCreated)
|
|
return
|
|
}
|
|
}
|
|
|
|
// 2. Service config changes.
|
|
if !reflect.DeepEqual(oldSvc, p.cfg.Service) {
|
|
w.WriteHeader(http.StatusCreated)
|
|
return
|
|
}
|
|
|
|
// Otherwise, reload is done.
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
p.cs.register(deactivationPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
|
// Non-cd mode always allowing deactivation.
|
|
if cdUID == "" {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
|
// Re-fetch pin code from API.
|
|
if rc, err := controld.FetchResolverConfig(loggerCtx, cdUID, appVersion, cdDev); rc != nil {
|
|
if rc.DeactivationPin != nil {
|
|
cdDeactivationPin.Store(*rc.DeactivationPin)
|
|
} else {
|
|
cdDeactivationPin.Store(defaultDeactivationPin)
|
|
}
|
|
} else {
|
|
p.Warn().Err(err).Msg("could not re-fetch deactivation pin code")
|
|
}
|
|
|
|
// If pin code not set, allowing deactivation.
|
|
if !deactivationPinSet() {
|
|
w.WriteHeader(http.StatusOK)
|
|
return
|
|
}
|
|
|
|
var req deactivationRequest
|
|
if err := json.NewDecoder(request.Body).Decode(&req); err != nil {
|
|
w.WriteHeader(http.StatusPreconditionFailed)
|
|
p.Error().Err(err).Msg("invalid deactivation request")
|
|
return
|
|
}
|
|
|
|
code := http.StatusForbidden
|
|
switch req.Pin {
|
|
case cdDeactivationPin.Load():
|
|
code = http.StatusOK
|
|
select {
|
|
case p.pinCodeValidCh <- struct{}{}:
|
|
default:
|
|
}
|
|
case defaultDeactivationPin:
|
|
// If the pin code was set, but users do not provide --pin, return proper code to client.
|
|
code = http.StatusBadRequest
|
|
}
|
|
w.WriteHeader(code)
|
|
}))
|
|
p.cs.register(cdPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
|
if cdUID != "" {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(cdUID))
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}))
|
|
p.cs.register(ifacePath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
|
res := &ifaceResponse{Name: iface}
|
|
// p.setDNS is only called when running as a service
|
|
if !service.Interactive() {
|
|
<-p.csSetDnsDone
|
|
if p.csSetDnsOk {
|
|
res.Name = p.runningIface
|
|
res.All = p.requiredMultiNICsConfig
|
|
res.OK = true
|
|
}
|
|
}
|
|
if err := json.NewEncoder(w).Encode(res); err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
http.Error(w, fmt.Sprintf("could not marshal iface data: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}))
|
|
p.cs.register(viewLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
|
lr, err := p.logReader()
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
defer lr.r.Close()
|
|
if lr.size == 0 {
|
|
w.WriteHeader(http.StatusMovedPermanently)
|
|
return
|
|
}
|
|
data, err := io.ReadAll(lr.r)
|
|
if err != nil {
|
|
http.Error(w, fmt.Sprintf("could not read log: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if err := json.NewEncoder(w).Encode(&logViewResponse{Data: string(data)}); err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
http.Error(w, fmt.Sprintf("could not marshal log data: %v", err), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}))
|
|
p.cs.register(sendLogsPath, http.HandlerFunc(func(w http.ResponseWriter, request *http.Request) {
|
|
if time.Since(p.internalLogSent) < logWriterSentInterval {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
r, err := p.logReader()
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
if r.size == 0 {
|
|
w.WriteHeader(http.StatusMovedPermanently)
|
|
return
|
|
}
|
|
req := &controld.LogsRequest{
|
|
UID: cdUID,
|
|
Data: r.r,
|
|
}
|
|
p.Debug().Msg("sending log file to ControlD server")
|
|
resp := logSentResponse{Size: r.size}
|
|
loggerCtx := ctrld.LoggerCtx(context.Background(), p.logger.Load())
|
|
if err := controld.SendLogs(loggerCtx, req, cdDev); err != nil {
|
|
p.Error().Msgf("could not send log file to ControlD server: %v", err)
|
|
resp.Error = err.Error()
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
} else {
|
|
p.Debug().Msg("sending log file successfully")
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
p.internalLogSent = time.Now()
|
|
}))
|
|
}
|
|
|
|
func jsonResponse(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|