-
-
Save RaghavSood/cbdc18cbc082a4c44cf73cacf0c72b5d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"flag" | |
"fmt" | |
"io" | |
"io/ioutil" | |
"os" | |
"path/filepath" | |
"github.com/pkg/errors" | |
) | |
// src and dst receive the values of the -src and -dst command-line flags.
var src, dst string
// Use the flag pacakge along with the init function to do cli flag inputs | |
// and kill the program early on if even the basic validations like not | |
// having enough data to even start copying a file fail | |
func init() { | |
flag.StringVar(&src, "src", "", "source file") | |
flag.StringVar(&dst, "dst", "", "destination file") | |
flag.Parse() | |
if src == "" { | |
fmt.Println("Source file cannot be blank") | |
os.Exit(1) | |
} | |
if dst == "" { | |
fmt.Println("Destination file cannot be blank") | |
os.Exit(1) | |
} | |
} | |
func main() { | |
// Your main function should generally do minimal work beyond starting | |
// your application and returning its overall results | |
copied, err := copyData(src, dst) | |
if err != nil { | |
fmt.Println("Failed to copy data:", err) | |
os.Exit(1) | |
} | |
if copied == -1 { | |
fmt.Println("Failed to copy data, unknown error") | |
os.Exit(1) | |
} | |
fmt.Printf("Copied %s to %s (%s)", src, dst, formatBytes(copied)) | |
} | |
func copyData(source, destination string) (int64, error) { | |
// os.Stat returns a FileInfo object - in go, it is conventional | |
// to use shorthand variables like err, fi, etc. consisting of the | |
// first letters of the word, of the first letter of each word in the | |
// long name | |
fi, err := os.Stat(source) | |
if err != nil { | |
// Wrapping errors with additional context is useful | |
return -1, errors.Wrap(err, "could not stat source file") | |
} | |
// the fs package already defines constants for mode and has helper functions | |
// you should use that directly, rather tha introducing your own constants | |
// makes for more portable code that is more easily understood by someone already | |
// familiar with go, but not with your approach | |
switch mode := fi.Mode(); { | |
case mode.IsDir(): | |
return copyDir(source, destination) | |
case mode.IsRegular(): | |
return copyFile(source, destination) | |
default: | |
return -1, errors.New(fmt.Sprintf("unexpected file mode %d", mode)) | |
} | |
} | |
func copyDir(source, destination string) (int64, error) { | |
source = filepath.Clean(source) | |
destination = filepath.Clean(destination) | |
// prefer MkdirAll over MkDir to cover cases where multiple folders | |
// don't exist | |
err := os.MkdirAll(destination, 0755) | |
if err != nil { | |
return -1, errors.Wrap(err, "could not create destination directory") | |
} | |
files, err := ioutil.ReadDir(source) | |
if err != nil { | |
return -1, errors.Wrap(err, "could not read source directory") | |
} | |
var bytesCopied int64 | |
for _, file := range files { | |
srcPath := filepath.Join(source, file.Name()) | |
dstPath := filepath.Join(destination, file.Name()) | |
if file.IsDir() { | |
copied, err := copyDir(srcPath, dstPath) | |
bytesCopied += copied | |
if err != nil { | |
return bytesCopied, errors.Wrap(err, "failed while copying subdirectory") | |
} | |
} else { | |
copied, err := copyFile(srcPath, dstPath) | |
bytesCopied += copied | |
if err != nil { | |
return bytesCopied, errors.Wrap(err, fmt.Sprintf("could not copy file: %s", srcPath)) | |
} | |
} | |
} | |
return bytesCopied, nil | |
} | |
// formatBytes renders a byte count as a human-readable string using decimal
// (SI) units. Values below 1000 are printed as whole bytes ("999 B"); larger
// values are scaled to the biggest fitting unit and printed with one decimal
// place ("1.5 kB", "2.0 MB", ...).
func formatBytes(bytes int64) string {
	const unit = int64(1000)
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}

	// divisor must start at one unit (kB) and grow by a factor of unit for
	// every further prefix; exponent indexes the matching prefix letter.
	exponent := 0
	divisor := unit
	for n := bytes / unit; n >= unit; n /= unit {
		divisor *= unit
		exponent++
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(divisor), "kMGTPE"[exponent])
}
// you can define variable names in the function definition | |
func copyFile(source, destination string) (bytesCopied int64, err error) { | |
bytesCopied = -1 | |
sourceFile, err := os.Open(source) | |
if err != nil { | |
return bytesCopied, errors.Wrap(err, "could not open source file") | |
} | |
// defer allows you to schedule code to run whenever the function returns | |
// in your original solution, if you hit the return in the destination file | |
// error if condition, your source file is never closed. using defer ensures | |
// that all code paths that terminate this function will run any deferred code | |
defer sourceFile.Close() | |
destinationFile, err := os.Create(destination) | |
if err != nil { | |
return bytesCopied, errors.Wrap(err, "could not create destination file") | |
} | |
// defer can also invoke entire functions, not just a single line - now, defer | |
// will actually run both of our defer instruction, including the sourceFile closure | |
defer func() { | |
if e := destinationFile.Close(); e != nil { | |
err = errors.Wrap(e, "could not close destination file") | |
} | |
}() | |
bytesCopied, err = io.Copy(destinationFile, sourceFile) | |
return bytesCopied, err | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment