diff --git a/cmd/cli/cli.go b/cmd/cli/cli.go index 9ed3602..7aae8b0 100644 --- a/cmd/cli/cli.go +++ b/cmd/cli/cli.go @@ -48,7 +48,11 @@ import ( // selfCheckInternalTestDomain is used for testing ctrld self response to clients. const selfCheckInternalTestDomain = "ctrld" + loopTestDomain -const windowsForwardersFilename = ".forwarders.txt" +const ( + windowsForwardersFilename = ".forwarders.txt" + oldBinSuffix = "_previous" + oldLogSuffix = ".1" +) var ( version = "dev" @@ -605,7 +609,13 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, // Config file. files = append(files, v.ConfigFileUsed()) // Log file. - files = append(files, cfg.Service.LogPath) + logFile := normalizeLogFilePath(cfg.Service.LogPath) + files = append(files, logFile) + // Backup log file. + oldLogFile := logFile + oldLogSuffix + if _, err := os.Stat(oldLogFile); err == nil { + files = append(files, oldLogFile) + } // Socket files. if dir, _ := socketDir(); dir != "" { files = append(files, filepath.Join(dir, ctrldControlUnixSock)) @@ -624,11 +634,15 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, files = append(files, absHomeDir(windowsForwardersFilename)) } // Binary itself. - bin, _ := os.Executable() if bin != "" && supportedSelfDelete { files = append(files, bin) } + // Backup file after upgrading. + oldBin := bin + oldBinSuffix + if _, err := os.Stat(oldBin); err == nil { + files = append(files, oldBin) + } for _, file := range files { if file == "" { continue @@ -922,7 +936,7 @@ NOTE: Uninstalling will set DNS to values provided by DHCP.`, if _, err := s.Status(); errors.Is(err, service.ErrNotInstalled) { svcInstalled = false } - oldBin := bin + "_previous" + oldBin := bin + oldBinSuffix baseUrl := upgradeChannel[upgradeChannelDefault] if len(args) > 0 { channel := args[0] diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 3c890f2..7c2894a 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -64,8 +64,11 @@ func Main() { } func normalizeLogFilePath(logFilePath string) string { - if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() { - return logFilePath + // In cleanup mode, we always want the full log file path. + if !cleanup { + if logFilePath == "" || filepath.IsAbs(logFilePath) || service.Interactive() { + return logFilePath + } } if homedir != "" { return filepath.Join(homedir, logFilePath) @@ -122,14 +125,14 @@ func initLoggingWithBackup(doBackup bool) { flags := os.O_CREATE | os.O_RDWR | os.O_APPEND if doBackup { // Backup old log file with .1 suffix. - if err := os.Rename(logFilePath, logFilePath+".1"); err != nil && !os.IsNotExist(err) { + if err := os.Rename(logFilePath, logFilePath+oldLogSuffix); err != nil && !os.IsNotExist(err) { mainLog.Load().Error().Msgf("could not backup old log file: %v", err) } else { // Backup was created, set flags for truncating old log file. flags = os.O_CREATE | os.O_RDWR } } - logFile, err := os.OpenFile(logFilePath, flags, os.FileMode(0o600)) + logFile, err := openLogFile(logFilePath, flags) if err != nil { mainLog.Load().Error().Msgf("failed to create log file: %v", err) os.Exit(1) diff --git a/cmd/cli/service_others.go b/cmd/cli/service_others.go index e9522f4..f4d73e5 100644 --- a/cmd/cli/service_others.go +++ b/cmd/cli/service_others.go @@ -9,3 +9,7 @@ import ( func hasElevatedPrivilege() (bool, error) { return os.Geteuid() == 0, nil } + +func openLogFile(path string, flags int) (*os.File, error) { + return os.OpenFile(path, flags, os.FileMode(0o600)) +} diff --git a/cmd/cli/service_windows.go b/cmd/cli/service_windows.go index a1010a8..d4e2449 100644 --- a/cmd/cli/service_windows.go +++ b/cmd/cli/service_windows.go @@ -1,6 +1,11 @@ package cli -import "golang.org/x/sys/windows" +import ( + "os" + "syscall" + + "golang.org/x/sys/windows" +) func hasElevatedPrivilege() (bool, error) { var sid *windows.SID @@ -22,3 +27,55 @@ func hasElevatedPrivilege() (bool, error) { token := windows.Token(0) return token.IsMember(sid) } + +func openLogFile(path string, mode int) (*os.File, error) { + if len(path) == 0 { + return nil, &os.PathError{Path: path, Op: "open", Err: syscall.ERROR_FILE_NOT_FOUND} + } + + pathP, err := syscall.UTF16PtrFromString(path) + if err != nil { + return nil, err + } + var access uint32 + switch mode & (os.O_RDONLY | os.O_WRONLY | os.O_RDWR) { + case os.O_RDONLY: + access = windows.GENERIC_READ + case os.O_WRONLY: + access = windows.GENERIC_WRITE + case os.O_RDWR: + access = windows.GENERIC_READ | windows.GENERIC_WRITE + } + if mode&os.O_CREATE != 0 { + access |= windows.GENERIC_WRITE + } + if mode&os.O_APPEND != 0 { + access &^= windows.GENERIC_WRITE + access |= windows.FILE_APPEND_DATA + } + + shareMode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE | syscall.FILE_SHARE_DELETE) + + var sa *syscall.SecurityAttributes + + var createMode uint32 + switch { + case mode&(os.O_CREATE|os.O_EXCL) == (os.O_CREATE | os.O_EXCL): + createMode = windows.CREATE_NEW + case mode&(os.O_CREATE|os.O_TRUNC) == (os.O_CREATE | os.O_TRUNC): + createMode = windows.CREATE_ALWAYS + case mode&os.O_CREATE == os.O_CREATE: + createMode = windows.OPEN_ALWAYS + case mode&os.O_TRUNC == os.O_TRUNC: + createMode = windows.TRUNCATE_EXISTING + default: + createMode = windows.OPEN_EXISTING + } + + handle, err := syscall.CreateFile(pathP, access, shareMode, sa, createMode, syscall.FILE_ATTRIBUTE_NORMAL, 0) + if err != nil { + return nil, &os.PathError{Path: path, Op: "open", Err: err} + } + + return os.NewFile(uintptr(handle), path), nil +}