From 271e0524817dfcc67415d19202288172f6264208 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E1=B4=8D=E1=B4=8F=E1=B4=8F=C9=B4D4=CA=80=E1=B4=8B?= Date: Wed, 17 Jul 2024 22:32:29 +0800 Subject: [PATCH] fix: Improve error handling and fix wrong compress dir for windows (#367) (#368) * fix: Improve error handling and fix wrong compress dir for windows * refactor: Refactor fileutil package for pass linter --- cmd/hack-browser-data/main.go | 2 +- utils/fileutil/filetutil.go | 92 +++++++++++++++++++++------------ utils/fileutil/fileutil_test.go | 52 +++++++++++++++++++ 3 files changed, 112 insertions(+), 34 deletions(-) create mode 100644 utils/fileutil/fileutil_test.go diff --git a/cmd/hack-browser-data/main.go b/cmd/hack-browser-data/main.go index cfa58ba..d10ae85 100644 --- a/cmd/hack-browser-data/main.go +++ b/cmd/hack-browser-data/main.go @@ -62,7 +62,7 @@ func Execute() { if compress { if err = fileutil.CompressDir(outputDir); err != nil { - slog.Error("compress error: ", "err", err) + slog.Error("compress error", "err", err) } slog.Info("compress success") } diff --git a/utils/fileutil/filetutil.go b/utils/fileutil/filetutil.go index e1e5d95..4197f62 100644 --- a/utils/fileutil/filetutil.go +++ b/utils/fileutil/filetutil.go @@ -5,7 +5,6 @@ import ( "bytes" "fmt" "os" - "path" "path/filepath" "strings" @@ -94,40 +93,67 @@ func ParentBaseDir(p string) string { func CompressDir(dir string) error { files, err := os.ReadDir(dir) if err != nil { - return err + return fmt.Errorf("read dir error: %w", err) } - b := new(bytes.Buffer) - zw := zip.NewWriter(b) - for _, f := range files { - fw, err := zw.Create(f.Name()) - if err != nil { - return err - } - name := path.Join(dir, f.Name()) - content, err := os.ReadFile(name) - if err != nil { - return err - } - _, err = fw.Write(content) - if err != nil { - return err - } - err = os.Remove(name) - if err != nil { - return err + if len(files) == 0 { + // Return an error if no files are found in the directory + return fmt.Errorf("no files to compress in: %s", dir) + } + + buffer := new(bytes.Buffer) + zipWriter := zip.NewWriter(buffer) + defer func() { + _ = zipWriter.Close() + }() + + for _, file := range files { + if err := addFileToZip(zipWriter, filepath.Join(dir, file.Name())); err != nil { + return fmt.Errorf("failed to add file to zip: %w", err) } } - if err := zw.Close(); err != nil { - return err + + if err := zipWriter.Close(); err != nil { + return fmt.Errorf("error closing zip writer: %w", err) } - filename := filepath.Join(dir, fmt.Sprintf("%s.zip", dir)) - outFile, err := os.Create(filepath.Clean(filename)) - if err != nil { - return err - } - _, err = b.WriteTo(outFile) - if err != nil { - return err - } - return outFile.Close() + + zipFilename := filepath.Join(dir, filepath.Base(dir)+".zip") + return writeFile(buffer, zipFilename) +} + +func addFileToZip(zw *zip.Writer, filename string) error { + content, err := os.ReadFile(filename) + if err != nil { + return fmt.Errorf("error reading file %s: %w", filename, err) + } + + fw, err := zw.Create(filepath.Base(filename)) + if err != nil { + return fmt.Errorf("error creating zip entry for %s: %w", filename, err) + } + + if _, err = fw.Write(content); err != nil { + return fmt.Errorf("error writing content to zip for %s: %w", filename, err) + } + + if err = os.Remove(filename); err != nil { + return fmt.Errorf("error removing original file %s: %w", filename, err) + } + + return nil +} + +func writeFile(buffer *bytes.Buffer, filename string) error { + outFile, err := os.Create(filename) + if err != nil { + return fmt.Errorf("error creating output file %s: %w", filename, err) + } + defer func() { + _ = outFile.Close() + }() + + if _, err = buffer.WriteTo(outFile); err != nil { + return fmt.Errorf("error writing data to file %s: %w", filename, err) + } + + return nil } diff --git a/utils/fileutil/fileutil_test.go b/utils/fileutil/fileutil_test.go new file mode 100644 index 0000000..6e61d00 --- /dev/null +++ b/utils/fileutil/fileutil_test.go @@ -0,0 +1,52 @@ +package fileutil + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTestDir(t *testing.T, files []string) string { + t.Helper() // Marks the function as a helper function. + + tempDir, err := os.MkdirTemp("", "testCompressDir") + require.NoError(t, err, "failed to create a temporary directory") + + for _, file := range files { + filePath := filepath.Join(tempDir, file) + err := os.WriteFile(filePath, []byte("test content"), 0o644) + require.NoError(t, err, "failed to create a test file") + } + return tempDir +} + +func TestCompressDir(t *testing.T) { + t.Run("Normal Operation", func(t *testing.T) { + tempDir := setupTestDir(t, []string{"file1.txt", "file2.txt", "file3.txt"}) + defer os.RemoveAll(tempDir) + + err := CompressDir(tempDir) + assert.NoError(t, err, "compressDir should not return an error") + + // Check if the zip file exists + zipFile := filepath.Join(tempDir, filepath.Base(tempDir)+".zip") + assert.FileExists(t, zipFile, "zip file should be created") + }) + + t.Run("Directory Does Not Exist", func(t *testing.T) { + err := CompressDir("/path/to/nonexistent/directory") + assert.Error(t, err, "should return an error for non-existent directory") + }) + + t.Run("Empty Directory", func(t *testing.T) { + tempDir, err := os.MkdirTemp("", "testEmptyDir") + require.NoError(t, err, "failed to create empty test directory") + defer os.RemoveAll(tempDir) + + err = CompressDir(tempDir) + assert.Error(t, err, "should return an error for an empty directory") + }) +}