Last active
March 26, 2021 09:04
-
-
Save masseelch/057a5d0b64c75596b6184048d4d0622b to your computer and use it in GitHub Desktop.
Cancelable Go Downloader With Progress Report
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package client | |
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"comlink"

	"github.com/recws-org/recws"
	"github.com/spf13/viper"
)
// downloadInProgressExtension is appended to a download's target path while the
// transfer is still running. Download renames the file to drop it once complete.
const (
	downloadInProgressExtension = ".tmp"
)
type ( | |
Progress struct { | |
ws *recws.RecConn | |
msg comlink.Message | |
Filename string `json:"filename"` | |
Total uint64 `json:"total"` | |
Sent uint64 `json:"sent"` | |
} | |
Downloader struct { | |
ws *recws.RecConn | |
dls map[string]context.CancelFunc | |
dlsLock *sync.Mutex | |
} | |
ctxReader struct { | |
ctx context.Context | |
r io.Reader | |
} | |
) | |
func NewDownloader(ws *recws.RecConn) *Downloader { | |
return &Downloader{ | |
ws: ws, | |
dls: make(map[string]context.CancelFunc), | |
dlsLock: &sync.Mutex{}, | |
} | |
} | |
func (d *Downloader) Cancel(msg comlink.Message) { | |
if cancel, ok := d.dls[string(msg.Payload)]; ok { | |
cancel() | |
} | |
} | |
// Download will download a url to a local file. It's efficient because it will | |
// write as it downloads and not load the whole file into memory. | |
func (d *Downloader) Download(msg comlink.Message) error { | |
base := string(msg.Payload) | |
filename := filepath.Join(viper.GetString("pp.media"), base) | |
tmpFilename := filename + downloadInProgressExtension | |
// Create the file, but give it a tmp file extension, this means we won't overwrite a | |
// file until it's downloaded, but we'll remove the tmp extension once downloaded. | |
out, err := os.Create(tmpFilename) | |
if err != nil { | |
return err | |
} | |
// Get the data from the server. | |
u := url.URL{ | |
Scheme: viper.GetString("server.scheme"), | |
Host: viper.GetString("server.host"), | |
Path: "/uploads/" + base, | |
} | |
resp, err := http.Get(u.String()) | |
if err != nil { | |
out.Close() | |
return err | |
} | |
defer resp.Body.Close() | |
// Have our ctx-aware reader able us to cancel the download. | |
ctx, cancel := context.WithCancel(context.Background()) | |
ctxR := &ctxReader{ | |
ctx: ctx, | |
r: resp.Body, | |
} | |
defer d.dlsLock.Unlock() | |
d.dlsLock.Lock() | |
d.dls[base] = cancel | |
d.dlsLock.Unlock() | |
// Create a Progress to count the already downloaded data. | |
pr := &Progress{ | |
ws: d.ws, | |
msg: msg.WithType(comlink.Progress), | |
Filename: base, | |
Total: uint64(resp.ContentLength), | |
} | |
// Start a timer to sent the current download progress to the server once a second. | |
ticker := time.NewTicker(time.Second) | |
done := make(chan struct{}) | |
defer func() { | |
// Break the ticker loop. | |
close(done) | |
// Stop the ticker routine. | |
ticker.Stop() | |
}() | |
go func() { | |
for { | |
select { | |
case <-ticker.C: | |
pr.ReportProgress() | |
case <-done: | |
return | |
} | |
} | |
}() | |
// Make sure all resources get cleaned up when the download ends (if aborted or not). | |
defer func() { | |
d.dlsLock.Lock() | |
delete(d.dls, base) | |
d.dlsLock.Unlock() | |
}() | |
// Download to file. | |
if _, err := io.Copy(out, io.TeeReader(ctxR, pr)); err != nil { | |
out.Close() | |
// Remove the tmp file. | |
os.Remove(tmpFilename) | |
return err | |
} | |
// Close the file without defer so it can happen before Rename() | |
out.Close() | |
if err = os.Rename(tmpFilename, filename); err != nil { | |
return err | |
} | |
// Report the progress once more so the server / managers know the download is complete. | |
pr.ReportProgress() | |
return nil | |
} | |
// Write will count the bytes that went through it. It will report the progress on the given | |
// websocket after it received threshold bytes since the last report. | |
func (p *Progress) Write(b []byte) (int, error) { | |
n := len(b) | |
p.Sent += uint64(n) | |
time.Sleep(time.Second) | |
return n, nil | |
} | |
func (p Progress) ReportProgress() { // TODO: Error handling | |
j, _ := json.Marshal(p) | |
fmt.Printf("Received %d of %d\n", p.Sent, p.Total) | |
_ = p.ws.WriteJSON(p.msg.WithPayload(j)) | |
} | |
func (r ctxReader) Read(p []byte) (int, error) { | |
if err := r.ctx.Err(); err != nil { | |
return 0, err | |
} | |
return r.r.Read(p) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment