Skip to content

Instantly share code, notes, and snippets.

@RaghavSood
Created July 21, 2021 19:04
Show Gist options
  • Save RaghavSood/cbdc18cbc082a4c44cf73cacf0c72b5d to your computer and use it in GitHub Desktop.
Save RaghavSood/cbdc18cbc082a4c44cf73cacf0c72b5d to your computer and use it in GitHub Desktop.
package main
import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"github.com/pkg/errors"
)
// main is the command-line entry point. It reads the -src and -dst flags,
// copies src to dst via copyData and prints a one-line summary. Any failure is
// reported on stderr and the process exits with status 1.
//
// Flags are parsed here rather than in an init function. init runs whenever
// the package is loaded, including under `go test`, where calling flag.Parse
// that early rejects the not-yet-registered -test.* flags. Keeping the values
// local also avoids mutable package-level state. Beyond that, main should do
// little more than start the work and report its overall result.
func main() {
	src := flag.String("src", "", "source file")
	dst := flag.String("dst", "", "destination file")
	flag.Parse()

	// Stop early when we don't even have enough information to start copying.
	// Diagnostics go to stderr so they never mix with the program's output.
	if *src == "" {
		fmt.Fprintln(os.Stderr, "Source file cannot be blank")
		os.Exit(1)
	}
	if *dst == "" {
		fmt.Fprintln(os.Stderr, "Destination file cannot be blank")
		os.Exit(1)
	}

	copied, err := copyData(*src, *dst)
	if err != nil {
		fmt.Fprintln(os.Stderr, "Failed to copy data:", err)
		os.Exit(1)
	}
	// Defensive: a negative count without an error should never happen.
	if copied < 0 {
		fmt.Fprintln(os.Stderr, "Failed to copy data, unknown error")
		os.Exit(1)
	}
	fmt.Printf("Copied %s to %s (%s)\n", *src, *dst, formatBytes(copied))
}
// copyData copies source to destination, dispatching on what source is.
// A directory is copied with copyDir and a regular file with copyFile. Any
// other kind of file (device, named pipe, socket, ...) is rejected.
//
// os.Stat follows symlinks, so a symlinked source is treated as whatever it
// points to. copyData returns the number of bytes copied, or -1 and an error
// if source cannot be inspected or has an unsupported type.
func copyData(source, destination string) (int64, error) {
	// In Go it is conventional to use short names such as fi (FileInfo) or
	// err, made from the first letter of the word, or the first letter of
	// each word in the long name.
	fi, err := os.Stat(source)
	if err != nil {
		// Wrapping errors with additional context makes failures traceable.
		return -1, errors.Wrap(err, "could not stat source file")
	}
	// Use the mode helpers from the standard library instead of home-grown
	// constants. That is more portable and instantly familiar to other Go
	// developers.
	switch mode := fi.Mode(); {
	case mode.IsDir():
		return copyDir(source, destination)
	case mode.IsRegular():
		return copyFile(source, destination)
	default:
		// %s renders the mode symbolically (e.g. "prw-r--r--" for a named
		// pipe). The raw numeric value means nothing to a user.
		return -1, errors.Errorf("unexpected file mode %s for %s", mode, source)
	}
}
// copyDir recursively copies the directory tree rooted at source into
// destination. It creates destination and any missing parents with mode 0755.
// It returns the total number of file bytes written. On failure it returns
// the bytes copied before the error (or -1 if nothing was copied) together
// with the error.
//
// Only real directories and regular files are copied, including symlinks
// that resolve to regular files. Any other entry aborts the copy with an
// error: named pipes, sockets, devices and symlinks to directories. This
// matches how copyData treats a non-regular source, since opening such
// entries can block or produce an endless stream.
func copyDir(source, destination string) (int64, error) {
	source = filepath.Clean(source)
	destination = filepath.Clean(destination)

	// Refuse to copy a directory into itself or below itself. Otherwise the
	// newly created destination would show up while source is being walked
	// and the copy would recurse without end. In the source == destination
	// case, every file would be truncated onto itself.
	inside, err := isWithinDir(source, destination)
	if err != nil {
		return -1, errors.Wrap(err, "could not resolve source and destination paths")
	}
	if inside {
		return -1, errors.Errorf("cannot copy directory %s into itself (%s)", source, destination)
	}
	return copyTree(source, destination)
}

// isWithinDir reports whether path is dir itself or lies somewhere below it.
// The comparison is lexical on absolute paths; symlinks are not resolved.
func isWithinDir(dir, path string) (bool, error) {
	absDir, err := filepath.Abs(dir)
	if err != nil {
		return false, err
	}
	absPath, err := filepath.Abs(path)
	if err != nil {
		return false, err
	}
	rel, err := filepath.Rel(absDir, absPath)
	if err != nil {
		// No relative path exists (e.g. different volumes), so path cannot
		// be inside dir.
		return false, nil
	}
	parent := ".." + string(filepath.Separator)
	escapes := rel == ".." || (len(rel) >= len(parent) && rel[:len(parent)] == parent)
	return !escapes, nil
}

// copyTree does the recursive work for copyDir once the paths are validated.
func copyTree(source, destination string) (int64, error) {
	// Prefer MkdirAll over Mkdir so that missing parent folders are created too.
	if err := os.MkdirAll(destination, 0755); err != nil {
		return -1, errors.Wrap(err, "could not create destination directory")
	}
	entries, err := ioutil.ReadDir(source)
	if err != nil {
		return -1, errors.Wrap(err, "could not read source directory")
	}
	var bytesCopied int64
	for _, entry := range entries {
		srcPath := filepath.Join(source, entry.Name())
		dstPath := filepath.Join(destination, entry.Name())
		copied, err := copyEntry(entry, srcPath, dstPath)
		// A failed call may report -1. Only count bytes that were actually
		// written, so the partial total stays accurate.
		if copied > 0 {
			bytesCopied += copied
		}
		if err != nil {
			return bytesCopied, err
		}
	}
	return bytesCopied, nil
}

// copyEntry copies one entry listed by ioutil.ReadDir (which does not follow
// symlinks) from srcPath to dstPath.
func copyEntry(entry os.FileInfo, srcPath, dstPath string) (int64, error) {
	mode := entry.Mode()
	if mode.IsDir() {
		copied, err := copyTree(srcPath, dstPath)
		if err != nil {
			return copied, errors.Wrap(err, "failed while copying subdirectory")
		}
		return copied, nil
	}
	if mode&os.ModeSymlink != 0 {
		// Follow the link so a symlink to a regular file is copied by
		// content, but refuse whatever else it might point at.
		target, err := os.Stat(srcPath)
		if err != nil {
			return -1, errors.Wrapf(err, "could not copy file: %s", srcPath)
		}
		mode = target.Mode()
	}
	if !mode.IsRegular() {
		return -1, errors.Errorf("could not copy file: %s: unexpected file mode %s", srcPath, mode)
	}
	copied, err := copyFile(srcPath, dstPath)
	if err != nil {
		return copied, errors.Wrapf(err, "could not copy file: %s", srcPath)
	}
	return copied, nil
}
// formatBytes renders a byte count using decimal (SI) prefixes with one
// fractional digit, e.g. 999 -> "999 B", 1500 -> "1.5 kB",
// 2000000 -> "2.0 MB". Counts below 1000, including negative ones, are
// printed as plain bytes.
func formatBytes(bytes int64) string {
	const unit = int64(1000)
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	// divisor is the power of unit that matches the chosen prefix. It starts
	// at unit (the "k" scale) and grows by one factor per extra prefix step.
	divisor, exponent := unit, 0
	for n := bytes / unit; n >= unit; n /= unit {
		divisor *= unit
		exponent++
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(divisor), "kMGTPE"[exponent])
}
// copyFile copies the contents of the file source to destination. It creates
// destination, or truncates it if it already exists, and returns the number
// of bytes written. On failure it returns -1 (or the partial count if the
// data copy itself failed) and an error describing the failing step.
//
// The results are named so the deferred Close below can report its error to
// the caller.
func copyFile(source, destination string) (bytesCopied int64, err error) {
	bytesCopied = -1
	sourceFile, err := os.Open(source)
	if err != nil {
		return bytesCopied, fmt.Errorf("could not open source file: %w", err)
	}
	// defer runs on every return path below, so the source is always closed.
	defer sourceFile.Close()

	// os.Create truncates an existing destination before anything is read.
	// If destination is the same file as source (same path, a hard link, or
	// a symlink to it), the source data would be destroyed, so detect that
	// first.
	srcInfo, err := sourceFile.Stat()
	if err != nil {
		return bytesCopied, fmt.Errorf("could not stat source file: %w", err)
	}
	if dstInfo, statErr := os.Stat(destination); statErr == nil && os.SameFile(srcInfo, dstInfo) {
		return bytesCopied, fmt.Errorf("source and destination are the same file: %s", destination)
	}

	destinationFile, err := os.Create(destination)
	if err != nil {
		return bytesCopied, fmt.Errorf("could not create destination file: %w", err)
	}
	// Close can report write errors, so its result matters. It is surfaced
	// only when the copy itself succeeded, so that a more specific copy
	// error is never masked.
	defer func() {
		if closeErr := destinationFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("could not close destination file: %w", closeErr)
		}
	}()

	bytesCopied, err = io.Copy(destinationFile, sourceFile)
	if err != nil {
		return bytesCopied, fmt.Errorf("could not copy contents: %w", err)
	}
	return bytesCopied, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment