Skip to content

Instantly share code, notes, and snippets.

@morgangallant
Created August 10, 2020 02:55
Show Gist options
  • Save morgangallant/e7f9d562f1639fca35a3c569d6739521 to your computer and use it in GitHub Desktop.
Save morgangallant/e7f9d562f1639fca35a3c569d6739521 to your computer and use it in GitHub Desktop.
automatic self-updating of go binaries from scratch
// This is a separate binary that you run when deploying your primary binary (main.go).
package main
import (
"fmt"
"log"
"operand/pkg/update"
"os/exec"
"github.com/pkg/errors"
)
// main runs the deploy pipeline; any failure is reported and terminates the
// process through log.Fatal.
func main() {
	err := run()
	if err == nil {
		return
	}
	log.Fatal(err)
}
// run publishes the next release of your-app-name. It asks the remote store
// for the latest published version, compiles the binary with that version
// plus one baked in, and uploads the freshly built ./your-app-name.
func run() error {
	const app = "your-app-name"
	latest, err := update.RemoteVersion(app)
	if err != nil {
		return err
	}
	next := latest + 1
	if err = compile(next); err != nil {
		return err
	}
	if err = update.Distribute(app, "./"+app, next); err != nil {
		return err
	}
	log.Printf("Successfully distributed version %d.", next)
	return nil
}
// compile runs `go build` on cmd/your-app-name, writing ./your-app-name and
// stamping version into main.vstring through the linker's -X flag. When the
// build fails, the compiler's combined output is printed to stdout before the
// wrapped error is returned.
func compile(version int) error {
	const (
		source = "cmd/your-app-name/your-app-name.go"
		output = "your-app-name"
	)
	ldflags := fmt.Sprintf("-ldflags=-X=main.vstring=%d", version)
	args := []string{"build", ldflags, "-o", output, source}
	buildLog, err := exec.Command("go", args...).CombinedOutput()
	if err == nil {
		return nil
	}
	fmt.Print(string(buildLog))
	return errors.Wrap(err, "failed to compile")
}
// This is the main package of your application you want to be self-updating.
// Right now, if an update can be done, this application exits after hot-swapping the new binary.
package main
import (
"log"
"math/rand"
"operand/pkg/update"
"os"
"strconv"
"time"
)
// vstring is injected by the Go linker (-ldflags "-X main.vstring=N") and
// holds this build's version as text. Do not set it by hand.
var vstring string

// version is the integer form of vstring, filled in by init(). Do not set it
// by hand.
var version int
// init converts the link-time version string into the integer version used
// for update checks. A binary built without
// -ldflags "-X main.vstring=<n>" (for example via plain `go run`) has no
// usable version, so it panics with an explanation of how to build it rather
// than a bare strconv error.
func init() {
	v, err := strconv.Atoi(vstring)
	if err != nil {
		panic("main.vstring " + strconv.Quote(vstring) +
			" is not an integer; build with -ldflags \"-X main.vstring=<n>\": " +
			err.Error())
	}
	version = v
}
// main checks for, and applies, any pending self-update. Errors are fatal.
func main() {
	err := run()
	if err == nil {
		return
	}
	log.Fatal(err)
}
// run performs one update check:
//   - It seeds math/rand's global source, which the update package uses for
//     temporary file names.
//   - It asks the remote store whether a version newer than this build exists.
//   - If so, it hot-swaps this executable and exits the process with status 0.
//
// The process is not restarted here.
func run() error {
	rand.Seed(time.Now().UnixNano())
	const app = "your-app-name"
	pending, latest, err := update.Check(app, version)
	if err != nil {
		return err
	}
	if !pending {
		log.Printf("up to date")
		return nil
	}
	log.Printf("There is a pending update!")
	if err := update.UpdateTo(app, latest); err != nil {
		return err
	}
	log.Printf("Successfully updated.")
	os.Exit(0)
	return nil
}
// This is the self-updating package which facilitates binary self-updating.
package update
import (
	"archive/tar"
	"bytes"
	"compress/gzip"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3manager"
	"github.com/inconshreveable/go-update"
	"github.com/pkg/errors"
)
// DON'T ACTUALLY DO THIS: storage configuration and credentials are
// hard-coded here for demonstration only. Real deployments should load them
// from the environment or a secret store.

// Location of the DigitalOcean Spaces bucket (accessed through the S3 API)
// that holds published releases.
const (
	bucket   = "bucket-name-here"
	region   = "nyc3"
	endpoint = region + ".digitaloceanspaces.com/"
)

// Static access credentials for the bucket.
const (
	key    = "key-goes-here"
	secret = "super-secret"
)
// DON'T ACTUALLY DO THIS: the storage clients below are built during package
// initialization from hard-coded constants. Importing this package therefore
// creates an S3 session, and session.Must panics if that fails.

// sess is the shared AWS SDK session pointed at the Spaces endpoint with
// static credentials.
var sess = session.Must(session.NewSession(&aws.Config{
	Credentials: credentials.NewStaticCredentials(key, secret, ""),
	Endpoint:    aws.String(endpoint),
	Region:      aws.String(region),
}))

// objs is the low-level S3 client used for listing objects. It turns off the
// SDK's cleaning of REST request URI paths.
var objs = s3.New(sess, &aws.Config{
	DisableRestProtocolURICleaning: aws.Bool(true),
})

// uploader and downloader are s3manager transfer helpers built on sess.
var (
	uploader   = s3manager.NewUploader(sess)
	downloader = s3manager.NewDownloader(sess)
)
// RemoteVersion returns the highest version of application published to the
// remote store: the largest N such that an object "<application>/N.tar.gz"
// exists. It returns -1 if no such object exists.
//
// Every page of the listing is scanned, because a single ListObjects call
// returns at most 1000 keys. The version is parsed from the object names
// rather than inferred from how many objects there are. Otherwise a gap left
// by a deleted release would make the next Distribute overwrite an existing
// version.
func RemoteVersion(application string) (int, error) {
	prefix := application + "/"
	latest := -1
	err := objs.ListObjectsPages(&s3.ListObjectsInput{
		Bucket: aws.String(bucket),
		Prefix: aws.String(prefix),
	}, func(page *s3.ListObjectsOutput, _ bool) bool {
		for _, obj := range page.Contents {
			if v, ok := versionFromKey(prefix, aws.StringValue(obj.Key)); ok && v > latest {
				latest = v
			}
		}
		return true // keep paging
	})
	if err != nil {
		return -1, errors.Wrapf(err, "failed to list objs for %s", application)
	}
	return latest, nil
}

// versionFromKey extracts N from an object key of the form
// "<prefix>N.tar.gz". For any other key it reports false: directory markers,
// nested paths, and non-numeric or negative names all fail.
func versionFromKey(prefix, key string) (int, bool) {
	const ext = ".tar.gz"
	if !strings.HasPrefix(key, prefix) || !strings.HasSuffix(key, ext) {
		return 0, false
	}
	name := strings.TrimSuffix(strings.TrimPrefix(key, prefix), ext)
	v, err := strconv.Atoi(name)
	if err != nil || v < 0 {
		return 0, false
	}
	return v, true
}
// validRelativePath reports whether p, a file name taken from a tar header,
// is safe to join onto the extraction directory. To pass, p must:
//   - be non-empty,
//   - not start with "/",
//   - contain no backslash,
//   - contain no ".." path element.
//
// Every ".." element is rejected, including a bare ".." or a trailing "/..",
// so an entry can never resolve above the destination directory. Names that
// merely contain dots, such as "foo..bar", are allowed.
func validRelativePath(p string) bool {
	if p == "" || strings.Contains(p, `\`) || strings.HasPrefix(p, "/") {
		return false
	}
	for _, elem := range strings.Split(p, "/") {
		if elem == ".." {
			return false
		}
	}
	return true
}
// decompress extracts the regular files of the gzip-compressed tar stream r
// into the existing directory dst.
//   - Directory entries are skipped, because the archive is expected to be
//     flat.
//   - Other non-regular entries are ignored.
//   - An entry whose name is not a safe relative path aborts extraction with
//     an error.
//
// Files that already exist in dst are truncated before being rewritten.
func decompress(r io.Reader, dst string) error {
	zr, err := gzip.NewReader(r)
	if err != nil {
		return errors.Wrap(err, "failed to create gzip reader")
	}
	defer zr.Close()
	tr := tar.NewReader(zr)
	for {
		header, err := tr.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return errors.Wrap(err, "file read failure")
		}
		if !validRelativePath(header.Name) {
			return errors.New("invalid file name in tar " + header.Name)
		}
		switch header.Typeflag {
		case tar.TypeDir:
			// We assume that the tar has no subdirectories, and since we only want to
			// extract the files into the dst/ directory, we don't extract any dir.
			continue
		case tar.TypeReg:
			target := filepath.Join(dst, header.Name)
			if err := extractFile(target, os.FileMode(header.Mode), tr); err != nil {
				return err
			}
		}
	}
}

// extractFile writes the contents of r to path, creating it with mode or
// truncating it if it exists. The file is always closed, and any copy or close
// failure is returned.
func extractFile(path string, mode os.FileMode, r io.Reader) error {
	fw, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
	if err != nil {
		return errors.Wrapf(err, "failed to create dst file for %s", path)
	}
	if _, err := io.Copy(fw, r); err != nil {
		fw.Close()
		return errors.Wrapf(err, "failed to decompress file %s", path)
	}
	// Close can surface a deferred write error; Wrapf returns nil for nil.
	return errors.Wrapf(fw.Close(), "failed to close file %s", path)
}
// tmpFileName returns a path in the system temporary directory (os.TempDir)
// named "operand-" followed by tlen random lowercase ASCII letters. The
// letters are drawn from math/rand's global source.
//
// The path is neither created nor guaranteed to be unused. Callers should
// create it exclusively (O_EXCL / os.Mkdir) to avoid races.
//
// folder is accepted for compatibility but does not affect the result.
// filepath.Join cleans away the trailing separator that used to be appended
// for folders.
func tmpFileName(tlen int, folder bool) string {
	const letters = "abcdefghijklmnopqrstuvwxyz"
	var b strings.Builder
	for i := 0; i < tlen; i++ {
		b.WriteByte(letters[rand.Intn(len(letters))])
	}
	return filepath.Join(os.TempDir(), "operand-"+b.String())
}
// withTemporaryFile creates a new, empty temporary file and passes it to exec.
// When exec returns, the file is closed and removed, and exec's error is
// returned.
//
// The file is created exclusively (O_EXCL) with owner-only permissions. A
// file or symlink already sitting at the randomly chosen path is therefore
// never reused or followed; creation fails instead.
func withTemporaryFile(exec func(*os.File) error) error {
	tf := tmpFileName(6, false)
	f, err := os.OpenFile(tf, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0600)
	if err != nil {
		return errors.Wrap(err, "failed to create temporary file")
	}
	defer os.RemoveAll(tf)
	defer f.Close()
	return exec(f)
}
// downloadRemote fetches "<application>/<version>.tar.gz" from the bucket into
// a temporary file, then extracts its regular files into dst. dst and any
// missing parents are created if needed.
func downloadRemote(application string, version int, dst string) error {
	return withTemporaryFile(func(f *os.File) error {
		// MkdirAll is a no-op when dst is already a directory, so no separate
		// existence check (and its check-then-act race) is needed.
		if err := os.MkdirAll(dst, 0755); err != nil {
			return errors.Wrap(err, "failed to create dst directory")
		}
		if _, err := downloader.Download(f, &s3.GetObjectInput{
			Bucket: aws.String(bucket),
			Key:    aws.String(fmt.Sprintf("%s/%d.tar.gz", application, version)),
		}); err != nil {
			return errors.Wrap(err, "failed to download remote gzip")
		}
		// The downloader writes through io.WriterAt. Rewind explicitly so
		// decompress reads from the start without depending on where the file
		// offset was left.
		if _, err := f.Seek(0, io.SeekStart); err != nil {
			return errors.Wrap(err, "failed to rewind downloaded archive")
		}
		return errors.Wrap(decompress(f, dst), "failed to decompress gzip")
	})
}
// hashFileContents reads r to the end and returns the lowercase hex encoding
// of the SHA-256 digest of everything it read.
func hashFileContents(r io.Reader) (string, error) {
	hasher := sha256.New()
	_, err := io.Copy(hasher, r)
	if err != nil {
		return "", errors.Wrap(err, "failed to write data to hasher")
	}
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest), nil
}
// hashFile opens the file at path and returns the hex-encoded SHA-256 digest
// of its contents.
func hashFile(path string) (string, error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		return "", errors.Wrap(openErr, "failed to open file")
	}
	defer file.Close()
	digest, err := hashFileContents(file)
	return digest, err
}
// meta is the metadata document shipped as meta.json inside each release
// archive. Distribute records the binary's digest in it, and UpdateTo checks
// the downloaded binary against that digest before applying it.
type meta struct {
	SHA256 string `json:"sha256"` // hex-encoded SHA-256 of the binary
}
func uploadFile(key string, body io.Reader) error {
_, err := uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Body: body,
})
return errors.Wrap(err, "failed to upload file")
}
// copyFile copies the contents of the file at src to dst, creating or
// truncating dst. It then gives dst the same mode as src and syncs it to
// stable storage.
//
// Failure to close dst is reported, not ignored, because a failed close can
// mean the data was not fully written.
func copyFile(src, dst string) (err error) {
	sf, err := os.Open(src)
	if err != nil {
		return err
	}
	defer sf.Close()
	sinfo, err := sf.Stat()
	if err != nil {
		return err
	}
	df, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		// Keep the first error; only surface the close error if all else succeeded.
		if cerr := df.Close(); err == nil {
			err = cerr
		}
	}()
	if _, err = io.Copy(df, sf); err != nil {
		return err
	}
	// Chmod through the open handle so the mode lands on the file we wrote.
	if err = df.Chmod(sinfo.Mode()); err != nil {
		return err
	}
	return df.Sync()
}
// withTemporaryFolder creates a fresh temporary directory and passes its path
// to exec. When exec returns, the directory and everything in it is removed,
// and exec's error is returned.
//
// os.Mkdir is used instead of MkdirAll so that a directory already sitting at
// the randomly chosen path causes an error rather than being silently reused.
// The directory gets owner-only permissions.
func withTemporaryFolder(exec func(string) error) error {
	tf := tmpFileName(6, true)
	if err := os.Mkdir(tf, 0700); err != nil {
		return errors.Wrap(err, "failed to create temporary directory")
	}
	defer os.RemoveAll(tf)
	return exec(tf)
}
// compress writes a gzip-compressed tar archive of every non-directory entry
// under rpath to dst.
//
// The archive is flat: each entry is stored under its base name only, so
// same-named files in different subdirectories would collide. Symlinks are
// recorded as links to their target rather than followed, and only regular
// files carry data.
func compress(rpath string, dst io.Writer) error {
	zr := gzip.NewWriter(dst)
	tw := tar.NewWriter(zr)
	err := filepath.Walk(rpath, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() {
			return nil
		}
		return addToTar(tw, path, info)
	})
	if err != nil {
		return errors.Wrap(err, "failed to walk filepath")
	}
	if err := tw.Close(); err != nil {
		return err
	}
	return zr.Close()
}

// addToTar writes a header for the file at path, named by its base name, to
// tw. For regular files it also copies the contents, closing the file
// afterwards.
func addToTar(tw *tar.Writer, path string, info os.FileInfo) error {
	// FileInfoHeader's second argument is the symlink target, not the file's
	// own path; it only matters for symlinks.
	var link string
	if info.Mode()&os.ModeSymlink != 0 {
		target, err := os.Readlink(path)
		if err != nil {
			return err
		}
		link = target
	}
	hdr, err := tar.FileInfoHeader(info, link)
	if err != nil {
		return err
	}
	hdr.Name = filepath.Base(path)
	if err := tw.WriteHeader(hdr); err != nil {
		return err
	}
	if !info.Mode().IsRegular() {
		return nil // no body for links, devices, pipes, ...
	}
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()
	_, err = io.Copy(tw, f)
	return err
}
// Distribute a binary located at bpath for a given application. The data is
// stored in a DO storage bucket.
func Distribute(application string, bpath string, ver int) error {
return withTemporaryFolder(func(fpath string) error {
bdst := filepath.Join(fpath, filepath.Base(bpath))
if err := copyFile(bpath, bdst); err != nil {
return errors.Wrap(err, "failed to copy binary file to temporary dir")
}
mf, err := os.Create(filepath.Join(fpath, "meta.json"))
if err != nil {
return errors.Wrap(err, "failed to create metadata file")
}
hash, err := hashFile(bpath)
if err != nil {
return errors.Wrap(err, "failed to hash binary file")
}
if err := json.NewEncoder(mf).Encode(meta{
SHA256: hash,
}); err != nil {
return errors.Wrap(err, "failed to encode metadata json")
}
mf.Close()
var tzdata bytes.Buffer
if err := compress(fpath, &tzdata); err != nil {
return errors.Wrap(err, "failed to compress folder")
}
return uploadFile(fmt.Sprintf("%s/%d.tar.gz", application, ver), &tzdata)
})
}
// Check reports whether the remote store holds a version of application newer
// than current, together with the remote version it found. If the remote
// version cannot be determined, it returns (false, -1, err).
func Check(application string, current int) (bool, int, error) {
	remote, err := RemoteVersion(application)
	if err != nil {
		return false, -1, errors.Wrap(err, "failed to get remote version")
	}
	newer := remote > current
	return newer, remote, nil
}
// UpdateTo will migrate this application to a specific version by hotswapping
// this binary to the new version.
func UpdateTo(application string, version int) error {
dir, err := os.Executable()
if err != nil {
return errors.Wrap(err, "failed to get path of executable")
}
dir = strings.TrimSuffix(dir, application)
sname := filepath.Join(dir, "operand-"+string(version))
if err := downloadRemote(application, version, sname); err != nil {
return errors.Wrap(err, "failed to download remote version")
}
mf, err := os.Open(filepath.Join(sname, "meta.json"))
if err != nil {
return errors.Wrap(err, "failed to open meta.json")
}
var m meta
if err := json.NewDecoder(mf).Decode(&m); err != nil {
return errors.Wrap(err, "failed to decode meta.json")
}
_ = mf.Close()
bpath := filepath.Join(sname, application)
hash, err := hashFile(bpath)
if err != nil {
return errors.Wrap(err, "failed to hash binary")
}
if hash != m.SHA256 {
return errors.New("hash does not match")
}
b, err := os.Open(bpath)
if err != nil {
return errors.Wrap(err, "failed to open binary file")
}
defer b.Close()
if err := update.Apply(b, update.Options{}); err != nil {
if rerr := update.RollbackError(err); rerr != nil {
return errors.Wrap(rerr, "failed to rollback from failed update")
}
return errors.Wrap(err, "failed to apply, yet rolled back successfully")
}
return os.RemoveAll(sname)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment