Skip to content

Instantly share code, notes, and snippets.

@masseelch
Last active March 26, 2021 09:04
Show Gist options
  • Save masseelch/057a5d0b64c75596b6184048d4d0622b to your computer and use it in GitHub Desktop.
Save masseelch/057a5d0b64c75596b6184048d4d0622b to your computer and use it in GitHub Desktop.
Cancelable Go Downloader With Progress Report
package client
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/recws-org/recws"
	"github.com/spf13/viper"

	"comlink"
)
// downloadInProgressExtension is appended to a file's name while its content
// is still being downloaded; the suffix is dropped by renaming once complete.
const (
	downloadInProgressExtension = ".tmp"
)
// Progress counts the bytes received for a single file. Download sends it
// over the websocket as the payload of a comlink.Progress message. Only the
// exported fields end up in the JSON encoding.
type Progress struct {
	ws  *recws.RecConn
	msg comlink.Message

	Filename string `json:"filename"`
	Total    uint64 `json:"total"`
	Sent     uint64 `json:"sent"`
}

// Downloader fetches uploads from the server. It remembers the cancel function
// of every running download, keyed by the file name, so Cancel can abort it.
type Downloader struct {
	ws      *recws.RecConn
	dls     map[string]context.CancelFunc
	dlsLock *sync.Mutex // meant to guard dls
}

// ctxReader wraps r and stops reading once ctx is done.
type ctxReader struct {
	ctx context.Context
	r   io.Reader
}
// NewDownloader returns a Downloader that talks to the server through ws and
// has no downloads registered yet.
func NewDownloader(ws *recws.RecConn) *Downloader {
	d := new(Downloader)
	d.ws = ws
	d.dls = map[string]context.CancelFunc{}
	d.dlsLock = new(sync.Mutex)
	return d
}
// Cancel aborts the running download whose name equals msg.Payload. It does
// nothing if no such download is registered.
//
// Download blocks until its transfer ends, so a Cancel that targets it always
// runs concurrently with it. The map lookup must therefore happen under
// dlsLock, which Download also holds while adding and removing entries. The
// cancel func itself is called outside the lock because it does not touch dls.
func (d *Downloader) Cancel(msg comlink.Message) {
	d.dlsLock.Lock()
	cancel, ok := d.dls[string(msg.Payload)]
	d.dlsLock.Unlock()

	if ok {
		cancel()
	}
}
// Download fetches the upload named by msg.Payload from the configured server
// ("server.scheme", "server.host", path "/uploads/<name>"). It stores the
// upload under the same name in the "pp.media" directory. The response body
// is streamed to disk, so the file is never held in memory as a whole.
//
// Data is first written to <name>.tmp and renamed to <name> only after the
// transfer succeeded, so an existing file is never replaced by a partial one.
// While the transfer runs, the progress is reported over the websocket once a
// second, and once more after the rename. A running download can be aborted
// with Cancel using the same payload.
//
// An error is returned in these cases:
//   - the name does not resolve to an entry inside the media directory;
//   - a download of the same name is already running;
//   - the server does not answer 200 OK;
//   - a file or network operation fails;
//   - the download is canceled.
//
// On every error path the tmp file, if created, is removed.
func (d *Downloader) Download(msg comlink.Message) error {
	base := string(msg.Payload)
	mediaDir := viper.GetString("pp.media")
	filename := filepath.Join(mediaDir, base)
	// The name arrives over the network; do not let "../" escape the media dir.
	if !withinDir(mediaDir, filename) {
		return fmt.Errorf("download %q: path is not inside the media directory", base)
	}
	tmpFilename := filename + downloadInProgressExtension

	// Register the cancel func before doing any work so Cancel can abort every
	// stage, including the HTTP request. A second concurrent download of the
	// same name is refused. It would truncate and share the same tmp file, and
	// it would overwrite (and later delete) this download's cancel func.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // release the context even if it was never canceled
	d.dlsLock.Lock()
	if _, running := d.dls[base]; running {
		d.dlsLock.Unlock()
		return fmt.Errorf("download %q: already in progress", base)
	}
	d.dls[base] = cancel
	d.dlsLock.Unlock()
	defer func() {
		d.dlsLock.Lock()
		delete(d.dls, base)
		d.dlsLock.Unlock()
	}()

	// Write to a tmp file, so an existing file is only replaced once the new
	// content is complete.
	out, err := os.Create(tmpFilename)
	if err != nil {
		return err
	}
	committed := false
	defer func() {
		if committed {
			return
		}
		// out may already be closed here; that Close error is irrelevant.
		_ = out.Close()
		_ = os.Remove(tmpFilename)
	}()

	u := url.URL{
		Scheme: viper.GetString("server.scheme"),
		Host:   viper.GetString("server.host"),
		Path:   "/uploads/" + base,
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("download %q: unexpected status %s", base, resp.Status)
	}

	// ContentLength is -1 when the server does not announce a length. Report 0
	// then instead of wrapping around to the maximum uint64.
	total := resp.ContentLength
	if total < 0 {
		total = 0
	}
	pr := &Progress{
		ws:       d.ws,
		msg:      msg.WithType(comlink.Progress),
		Filename: base,
		Total:    uint64(total),
	}

	// The copy below updates pr while the reporter goroutine reads it, so both
	// sides go through mu. The reporter sends a snapshot so the (possibly slow)
	// websocket write does not hold up the download.
	var mu sync.Mutex
	ticker := time.NewTicker(time.Second)
	done := make(chan struct{})
	reporterDone := make(chan struct{})
	go func() {
		defer close(reporterDone)
		for {
			select {
			case <-ticker.C:
				mu.Lock()
				snapshot := *pr
				mu.Unlock()
				snapshot.ReportProgress()
			case <-done:
				return
			}
		}
	}()

	// ctxReader makes the copy stop between reads as soon as ctx is canceled.
	ctxR := &ctxReader{ctx: ctx, r: resp.Body}
	_, copyErr := io.Copy(out, io.TeeReader(ctxR, &lockedWriter{mu: &mu, w: pr}))

	// Stop the reporter and wait for it to exit. From here on pr and the
	// websocket are used by this goroutine only, so the final report cannot
	// overlap a periodic one.
	ticker.Stop()
	close(done)
	<-reporterDone

	if copyErr != nil {
		return copyErr
	}
	// Close explicitly (not deferred): errors reported at close time must fail
	// the download, and the file must be closed before it is renamed.
	if err := out.Close(); err != nil {
		return err
	}
	if err := os.Rename(tmpFilename, filename); err != nil {
		return err
	}
	committed = true

	// Report once more so the server / managers see the completed download.
	pr.ReportProgress()
	return nil
}

// lockedWriter forwards writes to w while holding mu. This lets w's state be
// read safely by other goroutines that take the same mutex.
type lockedWriter struct {
	mu *sync.Mutex
	w  io.Writer
}

// Write writes b to the wrapped writer under the lock.
func (l *lockedWriter) Write(b []byte) (int, error) {
	l.mu.Lock()
	defer l.mu.Unlock()
	return l.w.Write(b)
}

// withinDir reports whether path names an entry strictly inside dir. That
// means it is neither dir itself nor located outside of it.
func withinDir(dir, path string) bool {
	rel, err := filepath.Rel(dir, path)
	if err != nil {
		return false
	}
	return rel != "." && rel != ".." &&
		!strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
// Write adds len(b) to p.Sent so the download progress can be reported later
// (ReportProgress does the reporting). It never fails and always consumes all
// of b, which makes it usable as the side writer of an io.TeeReader.
//
// The former per-call time.Sleep(time.Second) was removed. io.Copy hands over
// at most 32 KiB per Write, so the sleep capped every download at roughly one
// chunk per second.
func (p *Progress) Write(b []byte) (int, error) {
	n := len(b)
	p.Sent += uint64(n)
	return n, nil
}
// ReportProgress prints the current counters to stdout. It then sends p,
// encoded as JSON, as the payload of p.msg over the websocket. The value
// receiver means it works on a copy of p taken at call time.
//
// Reporting is best-effort: encoding and send failures are printed instead
// of returned. Nothing is sent if encoding fails, rather than sending a null
// payload.
func (p Progress) ReportProgress() {
	j, err := json.Marshal(p)
	if err != nil {
		fmt.Printf("Encoding progress of %q failed: %v\n", p.Filename, err)
		return
	}
	fmt.Printf("Received %d of %d\n", p.Sent, p.Total)
	if err := p.ws.WriteJSON(p.msg.WithPayload(j)); err != nil {
		fmt.Printf("Sending progress of %q failed: %v\n", p.Filename, err)
	}
}
// Read implements io.Reader. While the context is still active it delegates
// to the wrapped reader. Once the context is done it returns the context's
// error without reading anything.
func (r ctxReader) Read(p []byte) (int, error) {
	ctxErr := r.ctx.Err()
	if ctxErr == nil {
		return r.r.Read(p)
	}
	return 0, ctxErr
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment