Skip to content

Instantly share code, notes, and snippets.

@lrstanley
Created May 14, 2023 15:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lrstanley/de40012216c188aede5b57fb6e4c2c88 to your computer and use it in GitHub Desktop.
Save lrstanley/de40012216c188aede5b57fb6e4c2c88 to your computer and use it in GitHub Desktop.
Example of self-updating a Go CLI from a Nexus repository
package main
import (
"context"
"crypto/tls"
"fmt"
"io"
"net/http"
"os"
"regexp"
"runtime"
"sort"
"time"
"github.com/Masterminds/semver/v3"
"github.com/apex/log"
"github.com/fynelabs/selfupdate"
)
// reVersion extracts version directory names from a Nexus "browse" HTML page
// (anchors of the form <a href="1.2.3/">). The "version" group captures the
// directory name without its trailing slash.
//
// Besides MAJOR.MINOR.PATCH it accepts a semver prerelease ("-rc.1") and/or
// build-metadata ("+build.5") suffix. "-" must be part of the suffix class;
// without it prerelease directories are never matched and --allow-prerelease
// has no effect.
var reVersion = regexp.MustCompile(`(?m)<a href="(?P<version>\d+\.\d+\.\d+[a-zA-Z0-9._+~-]*)/">`)
// Flags is the top-level command-line definition. Fields carrying a `command`
// struct tag declare sub-commands. NOTE(review): the tag format looks like
// github.com/jessevdk/go-flags; confirm against main.go.
//
// Pulled from main.go.
type Flags struct {
// Wrap global flags and configurations to add sub-command logic.
*models.Flags
// Login handles the "login" sub-command (CommandLogin is defined elsewhere).
Login CommandLogin `command:"login" description:"login to example"`
// Update handles the "update" sub-command; see CommandUpdate.
Update CommandUpdate `command:"update" description:"update example-cli"`
// Get groups the "get" sub-commands; its fields were omitted from this example.
Get struct {
// TRUNCATED
} `command:"get" description:"get information about a resource"`
}
// CommandUpdate implements the "update" sub-command. It looks up the newest
// version published in a Nexus repository and replaces the running binary with
// it (see Execute). The exported fields are presumably filled in by the flag
// parser from their `long`/`default` tags; confirm in main.go.
type CommandUpdate struct {
// NexusURL is the Nexus base URL. Request paths are appended with "%s/...",
// so a trailing slash here would produce a double slash in the URL.
NexusURL string `long:"nexus-url" description:"URL to use for nexus" default:"https://nexus.example.com"`
// NexusRepository names the Nexus repository that hosts the release binaries.
NexusRepository string `long:"nexus-repository" description:"Nexus repository to use" default:"some-repository-name"`
// TLSSkipVerify disables certificate verification on the HTTP client
// (InsecureSkipVerify in initClient). Intended for testing only.
TLSSkipVerify bool `long:"tls-skip-verify" description:"Skip TLS verification"`
// AllowPrerelease keeps versions with a prerelease component when selecting
// the newest version; otherwise they are skipped.
AllowPrerelease bool `long:"allow-prerelease" description:"Allow updating to prerelease versions"`
// client is created lazily by initClient and reused for every request.
client *http.Client
}
// Execute implements the "update" sub-command. It asks Nexus for the newest
// published version, downloads the example_<GOOS>_<GOARCH>[.exe] binary for
// that version, and swaps it in for the running executable with
// selfupdate.Apply. args is unused.
//
// When already up to date it logs that fact and exits with status 0. Download
// and apply failures terminate the process with a non-zero status rather than
// being returned; the error result is currently always nil.
func (c *CommandUpdate) Execute(args []string) error {
	c.initClient()

	// required=true: a failed version lookup is fatal rather than ignored.
	newVersion := c.needsUpdate(context.Background(), true)
	if newVersion == nil {
		logger.Info("example-cli is up to date")
		os.Exit(0)
	}

	binary := fmt.Sprintf("example_%s_%s", runtime.GOOS, runtime.GOARCH)
	if runtime.GOOS == "windows" {
		binary += ".exe"
	}

	uri := fmt.Sprintf(
		"%s/repository/%s/%s/%s/%s/%s",
		c.NexusURL,
		c.NexusRepository,
		productName,
		applicationName,
		// Original() is the directory name exactly as it was listed by Nexus.
		newVersion.Original(),
		binary,
	)

	// NOTE(review): c.client has a 10s overall timeout (see initClient), which
	// also bounds reading the response body, so large binaries over slow links
	// may time out. No checksum or signature is configured in
	// selfupdate.Options either, so integrity relies solely on TLS, which
	// --tls-skip-verify disables.
	logger.WithField("uri", uri).Info("downloading update")
	resp, err := c.client.Get(uri)
	if err != nil {
		logger.WithError(err).Fatal("error downloading update")
	}

	// A non-200 body is not the binary (typically an HTML error page).
	// Applying it would overwrite the running executable with garbage.
	if resp.StatusCode != http.StatusOK {
		resp.Body.Close()
		logger.WithField("status", resp.Status).Fatal("error downloading update")
	}

	err = selfupdate.Apply(resp.Body, selfupdate.Options{})
	resp.Body.Close()
	if err != nil {
		logger.WithError(err).Error("error applying update")
		if rerr := selfupdate.RollbackError(err); rerr != nil {
			logger.WithError(rerr).Error("error rolling back update")
		}
		os.Exit(1)
	}

	logger.WithField("new-version", newVersion.String()).Info("example-cli updated successfully")
	return nil
}
// initClient lazily builds the HTTP client shared by the version check and the
// binary download. It is a no-op when a client already exists.
//
// The transport is cloned from http.DefaultTransport so its defaults are kept:
//   - proxy settings from the environment (HTTP_PROXY/HTTPS_PROXY/NO_PROXY)
//   - dial and TLS-handshake timeouts
//   - connection pooling
//
// Only TLS certificate verification is overridden, per --tls-skip-verify.
func (c *CommandUpdate) initClient() {
	if c.client != nil {
		return
	}

	var transport *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	} else {
		// DefaultTransport was replaced by something else; at least keep
		// honoring proxy environment variables.
		transport = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	transport.TLSClientConfig = &tls.Config{
		// Opt-in only, via --tls-skip-verify.
		InsecureSkipVerify: c.TLSSkipVerify,
	}

	c.client = &http.Client{
		// Overall per-request limit, including reading the response body.
		Timeout:   10 * time.Second,
		Transport: transport,
	}
}
// needsUpdate returns the newest published version when it is strictly newer
// than the running version. It returns nil when the running version is the
// same or newer.
//
// If the version lookup fails and required is false, the error is logged at
// debug level and nil is returned. If required is true, the failure is
// reported via logger.Fatal, which is expected to terminate the program.
func (c *CommandUpdate) needsUpdate(ctx context.Context, required bool) *semver.Version {
	c.initClient()
	logger.WithField("current-version", version).Debug("checking for updates")

	running, available, err := c.fetchVersions(ctx)
	switch {
	case err != nil && required:
		logger.WithError(err).Fatal("error fetching versions")
	case err != nil:
		logger.WithError(err).Debug("error fetching versions")
		return nil
	}

	// fetchVersions sorts newest first, so the head of the list is the candidate.
	latest := available[0]
	if running.LessThan(latest) {
		return latest
	}
	return nil
}
// fetchVersions parses the running version and lists the versions published
// under <NexusURL>/service/rest/repository/browse/<repo>/<product>/<app>/.
//
// It returns the running version and the published versions sorted newest
// first:
//   - prereleases are dropped unless AllowPrerelease is set
//   - directory names that are not valid semver are skipped with a warning
//
// It returns an error when any of the following happens:
//   - the current version cannot be parsed
//   - the request fails
//   - the response is not 200 OK
//   - no usable versions are found
//
// When err is nil, versions is guaranteed to be non-empty.
func (c *CommandUpdate) fetchVersions(ctx context.Context) (current *semver.Version, versions []*semver.Version, err error) {
	current, err = semver.NewVersion(version)
	if err != nil {
		return nil, nil, fmt.Errorf("error parsing current version %q: %w", version, err)
	}

	// initClient is a no-op once the client exists. Calling it here keeps this
	// method safe to use without going through needsUpdate first; otherwise
	// c.client.Do would run on a nil client.
	c.initClient()

	uri := fmt.Sprintf(
		"%s/service/rest/repository/browse/%s/%s/%s/",
		c.NexusURL,
		c.NexusRepository,
		productName,
		applicationName,
	)

	logger.WithField("uri", uri).Debug("fetching versions")

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, http.NoBody)
	if err != nil {
		return current, nil, fmt.Errorf("error creating version check request: %w", err)
	}

	req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml")
	// Presumably to avoid stale directory listings from intermediate caches.
	req.Header.Set("Cache-Control", "no-cache")
	req.Header.Set("User-Agent", fmt.Sprintf("%s/%s/%s", productName, applicationName, version))

	resp, err := c.client.Do(req)
	if err != nil {
		return current, nil, fmt.Errorf("error performing version check request: %w", err)
	}
	defer resp.Body.Close()

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return current, nil, fmt.Errorf("error reading version response: %w", err)
	}
	body := string(data)

	logger.WithFields(log.Fields{
		"status": resp.Status,
		"body":   body,
	}).Debug("version check response")

	if resp.StatusCode != http.StatusOK {
		return current, nil, fmt.Errorf("error performing version check request: %s", resp.Status)
	}

	for _, match := range reVersion.FindAllStringSubmatch(body, -1) {
		raw := match[1]

		v, perr := semver.NewVersion(raw)
		if perr != nil {
			logger.WithError(perr).WithField("version", raw).Warn("error parsing version, skipping")
			continue
		}

		if !c.AllowPrerelease && v.Prerelease() != "" {
			logger.WithField("version", raw).Debug("skipping prerelease version")
			continue
		}

		logger.WithField("version", v.String()).Debug("found version")
		versions = append(versions, v)
	}

	if len(versions) == 0 {
		return current, nil, fmt.Errorf("no versions found")
	}

	// Newest first; needsUpdate compares against versions[0].
	sort.Sort(sort.Reverse(semver.Collection(versions)))

	return current, versions, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment