diff --git a/internal/app/app.go b/internal/app/app.go index fef5010b..ca129df4 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -541,6 +541,8 @@ func (a *App) RunWithContext(ctx context.Context) error { } srv := &http.Server{Addr: addr, Handler: a.router} + var mainMux *mainServerMux + httpRedirect := config.ServerHTTPRedirectEnabled(&a.config.Server) if tlsMode != mainTLSOff { srv.TLSConfig = tlsConf if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil { @@ -557,6 +559,9 @@ func (a *App) RunWithContext(ctx context.Context) error { zap.String("address", addr), ) } + if httpRedirect { + a.logger.Info("已启用 HTTP→HTTPS 自动跳转(同端口嗅探分流)", zap.String("address", addr)) + } } else { a.logger.Info("启动 HTTP 主服务", zap.String("address", addr)) } @@ -566,7 +571,11 @@ func (a *App) RunWithContext(ctx context.Context) error { <-ctx.Done() shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := srv.Shutdown(shutdownCtx); err != nil { + if mainMux != nil { + if err := mainMux.Shutdown(shutdownCtx); err != nil { + a.logger.Error("HTTP/HTTPS 分流服务器关闭失败", zap.Error(err)) + } + } else if err := srv.Shutdown(shutdownCtx); err != nil { a.logger.Error("HTTP服务器关闭失败", zap.Error(err)) } if mcpServer != nil { @@ -577,12 +586,26 @@ func (a *App) RunWithContext(ctx context.Context) error { }() var err error - switch tlsMode { - case mainTLSOff: + switch { + case tlsMode != mainTLSOff && httpRedirect: + var tlsConfReady *tls.Config + tlsConfReady, err = ensureMainTLSConfigCerts(tlsMode, tlsConf, certFile, keyFile) + if err != nil { + return fmt.Errorf("加载 TLS 证书: %w", err) + } + srv.TLSConfig = tlsConfReady + var ln net.Listener + ln, err = net.Listen("tcp", addr) + if err != nil { + return err + } + mainMux = newMainServerMux(ln, srv, portFromListenAddr(addr), a.logger.Logger) + err = mainMux.Serve() + case tlsMode == mainTLSOff: err = srv.ListenAndServe() - case mainTLSFromFiles: + case tlsMode == mainTLSFromFiles: err = srv.ListenAndServeTLS(certFile, keyFile) - case mainTLSInMemorySelfSigned: + case tlsMode == mainTLSInMemorySelfSigned: var ln net.Listener ln, err = tls.Listen("tcp", addr, srv.TLSConfig) if err == nil { diff --git a/internal/app/main_server_http_redirect.go b/internal/app/main_server_http_redirect.go new file mode 100644 index 00000000..c4817074 --- /dev/null +++ b/internal/app/main_server_http_redirect.go @@ -0,0 +1,196 @@ +package app + +import ( + "bufio" + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "strconv" + "sync" + "time" + + "go.uber.org/zap" +) + +// peekedConn 在已预读首字节后仍将连接交给 net/http 或 crypto/tls。 +type peekedConn struct { + net.Conn + r *bufio.Reader +} + +func (c *peekedConn) Read(p []byte) (int, error) { + return c.r.Read(p) +} + +// oneConnListener 供 http.Server.Serve 处理单条 TCP 连接(含 keep-alive)。 +type oneConnListener struct { + conn net.Conn + addr net.Addr + once sync.Once +} + +func (l *oneConnListener) Accept() (net.Conn, error) { + var c net.Conn + l.once.Do(func() { + c = l.conn + l.conn = nil + }) + if c == nil { + return nil, net.ErrClosed + } + return c, nil +} + +func (l *oneConnListener) Close() error { return nil } +func (l *oneConnListener) Addr() net.Addr { return l.addr } + +func isTLSHandshakeRecord(b byte) bool { + return b == 0x16 +} + +func newHTTPToHTTPSRedirectHandler(httpsPort int) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := r.Host + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + var target string + if httpsPort == 443 { + target = fmt.Sprintf("https://%s%s", host, r.URL.RequestURI()) + } else { + target = fmt.Sprintf("https://%s:%d%s", host, httpsPort, r.URL.RequestURI()) + } + http.Redirect(w, r, target, http.StatusPermanentRedirect) + }) +} + +func portFromListenAddr(addr string) int { + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return 443 + } + p, err := strconv.Atoi(portStr) + if err != nil || p <= 0 { + return 443 + } + return p +} + +func ensureMainTLSConfigCerts(mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string) (*tls.Config, error) { + if mode != mainTLSFromFiles { + return tlsConf, nil + } + if tlsConf == nil { + tlsConf = &tls.Config{MinVersion: tls.VersionTLS12} + } + if len(tlsConf.Certificates) > 0 { + return tlsConf, nil + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + tlsConf.Certificates = []tls.Certificate{cert} + return tlsConf, nil +} + +type mainServerMux struct { + ln net.Listener + httpsSrv *http.Server + redirectSrv *http.Server + logger *zap.Logger +} + +func newMainServerMux(ln net.Listener, httpsSrv *http.Server, httpsPort int, logger *zap.Logger) *mainServerMux { + return &mainServerMux{ + ln: ln, + httpsSrv: httpsSrv, + redirectSrv: &http.Server{Handler: newHTTPToHTTPSRedirectHandler(httpsPort), ReadHeaderTimeout: 10 * time.Second}, + logger: logger, + } +} + +func (m *mainServerMux) Serve() error { + for { + conn, err := m.ln.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return http.ErrServerClosed + } + return err + } + go m.handleConn(conn) + } +} + +func (m *mainServerMux) handleConn(raw net.Conn) { + if err := raw.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil { + _ = raw.Close() + return + } + br := bufio.NewReader(raw) + b, err := br.Peek(1) + if err != nil { + _ = raw.Close() + return + } + _ = raw.SetReadDeadline(time.Time{}) + + pc := &peekedConn{Conn: raw, r: br} + ocl := &oneConnListener{conn: pc, addr: raw.LocalAddr()} + + if isTLSHandshakeRecord(b[0]) { + m.serveHTTPS(pc, raw.LocalAddr()) + return + } + if err := m.redirectSrv.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { + m.logger.Debug("HTTP 重定向连接处理结束", zap.Error(err)) + } +} + +// serveHTTPS 在已嗅探为 TLS 的连接上完成握手,再按 ALPN 走 HTTP/2 或 HTTP/1.1。 +// 不能对同一 http.Server 并发调用 Serve(TLSConfig!=nil),否则握手/ALPN 会异常(浏览器 ERR_SSL_PROTOCOL_ERROR)。 +func (m *mainServerMux) serveHTTPS(pc *peekedConn, localAddr net.Addr) { + tlsConn := tls.Server(pc, m.httpsSrv.TLSConfig) + handCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + if err := tlsConn.HandshakeContext(handCtx); err != nil { + m.logger.Debug("TLS 握手失败", zap.Error(err)) + _ = pc.Close() + return + } + + srv := m.httpsSrv + if srv.TLSNextProto != nil { + proto := tlsConn.ConnectionState().NegotiatedProtocol + if fn := srv.TLSNextProto[proto]; fn != nil { + fn(srv, tlsConn, srv.Handler) + return + } + } + + plain := *srv + plain.TLSConfig = nil + ocl := &oneConnListener{conn: tlsConn, addr: localAddr} + if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { + m.logger.Debug("HTTPS 连接处理结束", zap.Error(err)) + } +} + +func (m *mainServerMux) Shutdown(ctx context.Context) error { + _ = m.ln.Close() + var err1, err2 error + if m.httpsSrv != nil { + err1 = m.httpsSrv.Shutdown(ctx) + } + if m.redirectSrv != nil { + err2 = m.redirectSrv.Shutdown(ctx) + } + if err1 != nil { + return err1 + } + return err2 +} diff --git a/internal/app/main_server_http_redirect_test.go b/internal/app/main_server_http_redirect_test.go new file mode 100644 index 00000000..99037f29 --- /dev/null +++ b/internal/app/main_server_http_redirect_test.go @@ -0,0 +1,150 @@ +package app + +import ( + "crypto/tls" + "io" + "net" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "cyberstrike-ai/internal/config" + + "golang.org/x/net/http2" +) + +func TestNewHTTPToHTTPSRedirectHandler(t *testing.T) { + t.Parallel() + tests := []struct { + name string + httpsPort int + host string + uri string + wantTarget string + }{ + { + name: "non standard port", + httpsPort: 8080, + host: "127.0.0.1:8080", + uri: "/login?next=/", + wantTarget: "https://127.0.0.1:8080/login?next=/", + }, + { + name: "standard port", + httpsPort: 443, + host: "example.com:80", + uri: "/", + wantTarget: "https://example.com/", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + h := newHTTPToHTTPSRedirectHandler(tt.httpsPort) + req := httptest.NewRequest(http.MethodGet, "http://"+tt.host+tt.uri, nil) + req.Host = tt.host + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + if rec.Code != http.StatusPermanentRedirect { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusPermanentRedirect) + } + if got := rec.Header().Get("Location"); got != tt.wantTarget { + t.Fatalf("Location = %q, want %q", got, tt.wantTarget) + } + }) + } +} + +func TestIsTLSHandshakeRecord(t *testing.T) { + t.Parallel() + if !isTLSHandshakeRecord(0x16) { + t.Fatal("expected TLS handshake record") + } + if isTLSHandshakeRecord('G') { + t.Fatal("GET should not be TLS") + } +} + +func TestServerHTTPRedirectEnabled(t *testing.T) { + t.Parallel() + disabled := false + enabled := true + if config.ServerHTTPRedirectEnabled(nil) { + t.Fatal("nil config should disable redirect") + } + if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true}) { + t.Fatal("HTTPS without explicit flag should enable redirect") + } + if config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &disabled}) { + t.Fatal("explicit false should disable redirect") + } + if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &enabled}) { + t.Fatal("explicit true should enable redirect") + } + if config.ServerHTTPRedirectEnabled(&config.ServerConfig{}) { + t.Fatal("plain HTTP should not redirect") + } +} + +func TestMainServerMuxHTTPRedirectAndHTTPS(t *testing.T) { + cert, err := generateMainServerSelfSignedCert() + if err != nil { + t.Fatalf("generate cert: %v", err) + } + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "ok") + }) + srv := &http.Server{Handler: handler, TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + }} + if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil { + t.Fatalf("configure http2: %v", err) + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + mux := newMainServerMux(ln, srv, portFromListenAddr(ln.Addr().String()), nil) + go func() { _ = mux.Serve() }() + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12}, + }, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + } + addr := ln.Addr().String() + + httpResp, err := client.Get("http://" + addr + "/") + if err != nil { + t.Fatalf("http get: %v", err) + } + _ = httpResp.Body.Close() + if httpResp.StatusCode != http.StatusPermanentRedirect { + t.Fatalf("http status = %d, want %d", httpResp.StatusCode, http.StatusPermanentRedirect) + } + if got := httpResp.Header.Get("Location"); got != "https://127.0.0.1:"+strconv.Itoa(portFromListenAddr(addr))+"/" { + t.Fatalf("Location = %q", got) + } + + httpsResp, err := client.Get("https://" + addr + "/") + if err != nil { + t.Fatalf("https get: %v", err) + } + defer httpsResp.Body.Close() + if httpsResp.StatusCode != http.StatusOK { + t.Fatalf("https status = %d, want %d", httpsResp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(httpsResp.Body) + if string(body) != "ok" { + t.Fatalf("body = %q, want ok", body) + } +}