Skip to content

Instantly share code, notes, and snippets.

@hungneox
Created December 26, 2017 22:43
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hungneox/d22282a38339d111245e6a46a3c65e42 to your computer and use it in GitHub Desktop.
Save hungneox/d22282a38339d111245e6a46a3c65e42 to your computer and use it in GitHub Desktop.
Download a file in parallel with Go (one HTTP range request per CPU, then join the parts)
package main
import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"sort"
"strconv"
"sync"
"github.com/fatih/color"
"gopkg.in/cheggaaa/pb.v1"
)
/*
Part is one segment of a parallel download: the inclusive byte range
RangeFrom..RangeTo of the resource at URL, stored locally at Path.
*/
type Part struct {
// URL of the resource this segment is requested from; calculateParts gives
// every part of a download the same URL.
URL string
// Path of the local file this segment's bytes are written to.
Path string
// RangeFrom is the offset of the segment's first byte.
RangeFrom int64
// RangeTo is the offset of the segment's last byte. It is inclusive: it is
// sent verbatim as the end of the "bytes=RangeFrom-RangeTo" Range header.
RangeTo int64
}
// client is the shared HTTP client used both for the initial header probe and
// for every part request. It is the zero value: default transport, no timeout.
var client http.Client

// Response header names checked before a segmented download is attempted.
var (
	acceptRangeHeader   = "Accept-Ranges"
	contentLengthHeader = "Content-Length"
)

const (
	// downloadFolder is the directory that holds the per-part files.
	downloadFolder = "./parts"
	// fileName is the name of the joined output file.
	fileName = "go1.9.2.darwin-amd64.pkg"
)

// fileChan carries the paths of downloaded part files to main, and doneChan
// signals that every part goroutine has finished. Each is buffered with one
// slot per CPU.
var (
	fileChan = make(chan string, runtime.NumCPU())
	doneChan = make(chan bool, runtime.NumCPU())
)
// handleError aborts the program by panicking with err when err is non-nil,
// and does nothing otherwise.
//
// The original error value is panicked unchanged (rather than re-formatted
// into a new error), so a recover() site can still inspect it with
// errors.Is / errors.As.
func handleError(err error) {
	if err != nil {
		panic(err)
	}
}
// getHeader probes url with a GET request and returns the response so the
// caller can read its headers.
//
// The response body is closed before returning. Only the headers are meant to
// be used, and leaving the body open would leak the connection while the
// server streams the entire file.
//
// The program exits with status 1 when any of these holds, because a
// segmented range download needs the opposite of each:
//   - the final response is not 200 OK;
//   - Accept-Ranges is missing or "none";
//   - Content-Length is missing.
//
// Request errors panic via handleError.
func getHeader(url string) *http.Response {
	req, err := http.NewRequest("GET", url, nil)
	handleError(err)
	// Without this, the transport may negotiate gzip and then strip
	// Content-Length. Range requests address the raw (identity) bytes, so the
	// length must describe those.
	req.Header.Set("Accept-Encoding", "identity")
	resp, err := client.Do(req)
	handleError(err)
	resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		fmt.Fprintf(os.Stderr, "Unexpected response status: %s\n", resp.Status)
		os.Exit(1)
	}
	switch resp.Header.Get(acceptRangeHeader) {
	case "":
		fmt.Fprintln(os.Stderr, "Response does not contain Accept-Ranges header")
		os.Exit(1)
	case "none":
		// RFC 7233: "none" explicitly means range requests are not supported.
		fmt.Fprintln(os.Stderr, "Server does not accept byte range requests")
		os.Exit(1)
	}
	if resp.Header.Get(contentLengthHeader) == "" {
		fmt.Fprintln(os.Stderr, "Response does not contain Content-Length header")
		os.Exit(1)
	}
	return resp
}
// calculateParts splits a resource of size bytes, served at url, into cnn
// contiguous byte ranges. RangeFrom and RangeTo are both inclusive, matching
// the "bytes=from-to" Range header built from them.
//
// Every part except the last spans size/cnn bytes. The last part also absorbs
// the remainder and ends at byte size-1, the final valid offset.
//
// cnn is clamped to [1, size], so no part is ever empty or inverted and there
// is no division by zero. A non-positive size yields no parts.
//
// Part j is stored at downloadFolder/<file>.part<j>.
func calculateParts(cnn int64, size int64, url string) []Part {
	// Prefix for part file names; the same name main uses for the joined output.
	const file = "go1.9.2.darwin-amd64.pkg"

	switch {
	case size <= 0:
		return []Part{}
	case cnn < 1:
		cnn = 1
	case cnn > size:
		cnn = size
	}

	chunk := size / cnn // >= 1 after clamping
	parts := make([]Part, 0, cnn)
	for j := int64(0); j < cnn; j++ {
		from := chunk * j
		to := from + chunk - 1
		if j == cnn-1 {
			to = size - 1
		}
		name := fmt.Sprintf("%s.part%d", file, j)
		parts = append(parts, Part{
			URL:       url,
			Path:      filepath.Join(downloadFolder, name),
			RangeFrom: from,
			RangeTo:   to,
		})
	}
	return parts
}
// joinFile concatenates the part files named in files, in part order, into the
// file out. out is created if missing and truncated if present. A progress bar
// counts the parts joined.
//
// files is sorted in place by numeric part index, so "x.part10" follows
// "x.part9". A plain lexicographic sort would place it right after "x.part1"
// and corrupt any download split into more than ten parts.
//
// It returns the first error from opening out, copying a part, or closing out.
func joinFile(files []string, out string) error {
	sort.SliceStable(files, func(i, j int) bool {
		pi, ni := partSortKey(files[i])
		pj, nj := partSortKey(files[j])
		if pi != pj {
			return pi < pj
		}
		return ni < nj
	})

	fmt.Printf("Start joining \n")
	// O_TRUNC: otherwise a pre-existing, longer output keeps stale trailing bytes.
	outf, err := os.OpenFile(out, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	bar := pb.StartNew(len(files)).Prefix(color.CyanString("Joining"))
	defer bar.Finish()

	for _, f := range files {
		if err := copy(f, outf); err != nil {
			outf.Close()
			return err
		}
		bar.Increment()
	}
	// On a written file, Close can report a failed flush, so it is not ignored.
	return outf.Close()
}

// partSortKey splits name into its text up to and including a trailing
// ".part" marker, plus the numeric index after it. Names without a
// ".part<N>" extension return (name, -1).
func partSortKey(name string) (string, int) {
	const marker = ".part"
	ext := filepath.Ext(name)
	if len(ext) > len(marker) && ext[:len(marker)] == marker {
		if n, err := strconv.Atoi(ext[len(marker):]); err == nil && n >= 0 {
			return name[:len(name)-len(ext)+len(marker)], n
		}
	}
	return name, -1
}
// copy appends the entire contents of the file at path from to the writer to.
// It is a separate function so that each source file's deferred Close runs as
// soon as that file has been copied.
//
// It returns an error if the file cannot be opened or if any read or write
// fails, so a truncated part is never silently joined.
//
// NOTE: this package-level function shadows the builtin copy throughout the
// package.
func copy(from string, to io.Writer) error {
	f, err := os.Open(from)
	if err != nil {
		return err
	}
	defer f.Close()
	_, err = io.Copy(to, f)
	return err
}
// do downloads every part concurrently, one goroutine and one HTTP Range
// request per part, showing one progress bar per part.
//
// Each part's path is sent on fileChan after that part has been fully
// written. Once all goroutines have finished, true is sent on doneChan and the
// bar pool is stopped.
//
// Failures panic via handleError, as elsewhere in this program. These include
// request errors, responses that do not honour the range, and write errors.
// conn is currently unused; the number of connections equals len(parts).
func do(parts []Part, conn int64) {
	bars := make([]*pb.ProgressBar, 0, len(parts))
	for i, part := range parts {
		// The parts directory does not exist on a first run, and OpenFile
		// would fail without it.
		handleError(os.MkdirAll(filepath.Dir(part.Path), 0755))
		// Range bounds are inclusive on both ends, hence the +1.
		size := part.RangeTo - part.RangeFrom + 1
		prefix := color.YellowString(fmt.Sprintf("%s-%d", fileName, i))
		bars = append(bars, pb.New64(size).SetUnits(pb.U_BYTES).Prefix(prefix))
	}
	barpool, err := pb.StartPool(bars...)
	handleError(err)

	var wg sync.WaitGroup
	for i, p := range parts {
		wg.Add(1)
		go func(ith int, part Part) {
			defer wg.Done()
			downloadPart(ith, part, bars[ith])
		}(i, p)
	}
	wg.Wait()
	doneChan <- true
	barpool.Stop()
}

// downloadPart fetches part's inclusive byte range into part.Path, replacing
// any stale file from a previous run. It mirrors progress into bar and sends
// part.Path on fileChan once the range has been completely written.
func downloadPart(ith int, part Part, bar *pb.ProgressBar) {
	defer bar.Finish()
	fmt.Printf("Part %d\n", ith)

	req, err := http.NewRequest("GET", part.URL, nil)
	handleError(err)
	req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", part.RangeFrom, part.RangeTo))
	resp, err := client.Do(req)
	handleError(err)
	defer resp.Body.Close()

	// 206 is the normal reply. A 200 means the server ignored Range and is
	// sending the whole file, which is only correct if the whole file is
	// exactly this part.
	size := part.RangeTo - part.RangeFrom + 1
	wholeFileIsPart := resp.StatusCode == http.StatusOK && part.RangeFrom == 0 && resp.ContentLength == size
	if resp.StatusCode != http.StatusPartialContent && !wholeFileIsPart {
		handleError(fmt.Errorf("part %d: unexpected response status %s", ith, resp.Status))
	}

	// O_TRUNC (not O_APPEND): a rerun must not append onto an old part file.
	f, err := os.OpenFile(part.Path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
	handleError(err)
	_, err = io.Copy(io.MultiWriter(f, bar), resp.Body)
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	handleError(err)

	fileChan <- part.Path
}
// main downloads a fixed Go release in parallel, with one HTTP range request
// per CPU, then joins the parts into fileName.
//
// It uses the fan-in pattern from https://blog.golang.org/pipelines: do
// reports finished part files on fileChan and signals completion on doneChan.
func main() {
	conn := runtime.NumCPU()
	url := "https://storage.googleapis.com/golang/go1.9.2.darwin-amd64.pkg"

	response := getHeader(url)
	// Only the headers are needed; close the body so the connection is not leaked.
	response.Body.Close()
	contentLength := response.Header.Get(contentLengthHeader)
	fmt.Printf("Start download with %d connections \n", conn)
	length, err := strconv.ParseInt(contentLength, 10, 64)
	handleError(err)

	parts := calculateParts(int64(conn), length, url)
	go do(parts, int64(conn))

	var files []string
	for {
		select {
		case file := <-fileChan:
			files = append(files, file)
		case <-doneChan:
			// select chooses randomly among ready cases, so doneChan can win
			// while part paths are still buffered in fileChan. Every send to
			// fileChan happens before do sends on doneChan, so a non-blocking
			// drain here collects all of the remaining paths.
		drain:
			for {
				select {
				case file := <-fileChan:
					files = append(files, file)
				default:
					break drain
				}
			}
			handleError(joinFile(files, fileName))
			return
		}
	}
}