diff --git a/cmd/ctrld/cli.go b/cmd/ctrld/cli.go index acbdb2d..bb8a8da 100644 --- a/cmd/ctrld/cli.go +++ b/cmd/ctrld/cli.go @@ -134,6 +134,11 @@ func initCLI() { stopCh: stopCh, cfg: &cfg, } + if homedir == "" { + if dir, err := userHomeDir(); err == nil { + homedir = dir + } + } sockPath := filepath.Join(homedir, ctrldLogUnixSock) if addr, err := net.ResolveUnixAddr("unix", sockPath); err == nil { if conn, err := net.Dial(addr.Network(), addr.String()); err == nil { @@ -186,6 +191,11 @@ func initCLI() { } p.router = router.New(&cfg) + cs, err := newControlServer(filepath.Join(homedir, ctrldControlUnixSock)) + if err != nil { + mainLog.Warn().Err(err).Msg("could not create control server") + } + p.cs = cs // Processing --cd flag require connecting to ControlD API, which needs valid // time for validating server certificate. Some routers need NTP synchronization diff --git a/cmd/ctrld/control_client.go b/cmd/ctrld/control_client.go new file mode 100644 index 0000000..0a94c99 --- /dev/null +++ b/cmd/ctrld/control_client.go @@ -0,0 +1,29 @@ +package main + +import ( + "context" + "io" + "net" + "net/http" + "time" +) + +type controlClient struct { + c *http.Client +} + +func newControlClient(addr string) *controlClient { + return &controlClient{c: &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "unix", addr) + }, + }, + Timeout: time.Second * 5, + }} +} + +func (c *controlClient) post(path string, data io.Reader) (*http.Response, error) { + return c.c.Post("http://unix"+path, contentTypeJson, data) +} diff --git a/cmd/ctrld/control_server.go b/cmd/ctrld/control_server.go new file mode 100644 index 0000000..437e4a8 --- /dev/null +++ b/cmd/ctrld/control_server.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "net" + "net/http" + "os" + "time" +) + +const contentTypeJson = "application/json" + +type controlServer struct { + server *http.Server + mux *http.ServeMux + addr string +} + +func newControlServer(addr string) (*controlServer, error) { + mux := http.NewServeMux() + s := &controlServer{ + server: &http.Server{Handler: mux}, + mux: mux, + } + s.addr = addr + return s, nil +} + +func (s *controlServer) start() error { + _ = os.Remove(s.addr) + unixListener, err := net.Listen("unix", s.addr) + if err != nil { + return err + } + go s.server.Serve(unixListener) + return nil +} + +func (s *controlServer) stop() error { + _ = os.Remove(s.addr) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + return s.server.Shutdown(ctx) +} + +func (s *controlServer) register(pattern string, handler http.Handler) { + s.mux.Handle(pattern, jsonResponse(handler)) +} + +func (p *prog) registerControlServerHandler() { + // TODO: register handler here. +} + +func jsonResponse(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + next.ServeHTTP(w, r) + }) +} diff --git a/cmd/ctrld/control_server_test.go b/cmd/ctrld/control_server_test.go new file mode 100644 index 0000000..2bcd64a --- /dev/null +++ b/cmd/ctrld/control_server_test.go @@ -0,0 +1,54 @@ +package main + +import ( + "bytes" + "io" + "net/http" + "os" + "testing" +) + +func TestControlServer(t *testing.T) { + f, err := os.CreateTemp("", "") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + f.Close() + + s, err := newControlServer(f.Name()) + if err != nil { + t.Fatal(err) + } + pattern := "/ping" + respBody := []byte("pong") + s.register(pattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(respBody) + })) + if err := s.start(); err != nil { + t.Fatal(err) + } + + c := newControlClient(f.Name()) + resp, err := c.post(pattern, nil) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Fatalf("unepxected response code: %d", resp.StatusCode) + } + if ct := resp.Header.Get("content-type"); ct != contentTypeJson { + t.Fatalf("unexpected content type: %s", ct) + } + buf, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(buf, respBody) { + t.Errorf("unexpected response body, want: %q, got: %q", string(respBody), string(buf)) + } + if err := s.stop(); err != nil { + t.Fatal(err) + } +} diff --git a/cmd/ctrld/prog.go b/cmd/ctrld/prog.go index 24fd9d9..3298cd4 100644 --- a/cmd/ctrld/prog.go +++ b/cmd/ctrld/prog.go @@ -22,8 +22,9 @@ import ( ) const ( - defaultSemaphoreCap = 256 - ctrldLogUnixSock = "ctrld_start.sock" + defaultSemaphoreCap = 256 + ctrldLogUnixSock = "ctrld_start.sock" + ctrldControlUnixSock = "ctrld_control.sock" ) var logf = func(format string, args ...any) { @@ -45,6 +46,7 @@ type prog struct { waitCh chan struct{} stopCh chan struct{} logConn net.Conn + cs *controlServer cfg *ctrld.Config cache dnscache.Cacher @@ -183,6 +185,12 @@ func (p *prog) run() { if p.logConn != nil { _ = p.logConn.Close() } + if p.cs != nil { + p.registerControlServerHandler() + if err := p.cs.start(); err != nil { + mainLog.Warn().Err(err).Msg("could not start control server") + } + } wg.Wait() }