Skip to content

Instantly share code, notes, and snippets.

@klingtnet
Last active October 6, 2020 21:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save klingtnet/b66ecace3e87b10972245fec7e4c3fc5 to your computer and use it in GitHub Desktop.
Save klingtnet/b66ecace3e87b10972245fec7e4c3fc5 to your computer and use it in GitHub Desktop.
package main
import (
"bytes"
"fmt"
"io"
"io/ioutil"
nurl "net/url"
"os"
"path/filepath"
"strings"
"github.com/golang-migrate/migrate/v4/source"
)
func init() {
source.Register("embed", &EmbedDriver{})
}
// EmbedDriver is a database migration source driver for
// https://github.com/golang-migrate/migrate/tree/master/source that reads
// migration files from the embedded file set exposed by EmbeddedFiles and
// EmbeddedFile instead of from disk.
type EmbedDriver struct {
	// migrations holds the migrations parsed from the embedded file names.
	migrations *source.Migrations
	// basePath is the embedded directory (URL host + path) passed to Open.
	basePath string
}
// Open implements source.Driver. The url must use the "embed" scheme; its
// host and path together name the embedded directory to read, e.g.
// "embed://migrations" selects "migrations/1_init.up.sql".
//
// Every embedded file inside that directory is parsed with
// source.DefaultParse. A file whose name does not follow the migration naming
// scheme is reported as an error (wrapping the parser's error), and so is a
// duplicate migration. Files outside the directory are ignored. This includes
// sibling directories that merely share the name as a prefix, such as
// "migrations_old/...".
func (ed *EmbedDriver) Open(url string) (source.Driver, error) {
	u, err := nurl.Parse(url)
	if err != nil {
		return nil, err
	}
	if u.Scheme != "embed" {
		return nil, fmt.Errorf("incompatible scheme %q", u.Scheme)
	}
	basePath := u.Host + u.Path

	// Match on a directory boundary. Without the trailing slash, a base path
	// of "migrations" would also select "migrations_old/...", and the leftover
	// "_old/..." would then fail to parse and abort Open.
	prefix := basePath
	if prefix != "" && !strings.HasSuffix(prefix, "/") {
		prefix += "/"
	}

	ms := source.NewMigrations()
	for _, file := range EmbeddedFiles() {
		if !strings.HasPrefix(file, prefix) {
			continue
		}
		// Also drop one leading slash, as before, so an empty base path still
		// accepts keys of the form "/1_init.up.sql".
		name := strings.TrimPrefix(strings.TrimPrefix(file, prefix), "/")
		m, err := source.DefaultParse(name)
		if err != nil {
			// The parser's own error carries no file name; add it.
			return nil, fmt.Errorf("parsing embedded migration %q: %w", file, err)
		}
		if !ms.Append(m) {
			return nil, source.ErrDuplicateMigration{
				Migration: *m,
			}
		}
	}
	return &EmbedDriver{
		basePath:   basePath,
		migrations: ms,
	}, nil
}
// Close implements source.Driver. The driver holds no open resources, so
// Close always succeeds.
func (*EmbedDriver) Close() error { return nil }
// First implements source.Driver. It returns the first known migration
// version, or os.ErrNotExist when no migrations were loaded.
func (ed *EmbedDriver) First() (version uint, err error) {
	first, found := ed.migrations.First()
	if !found {
		return 0, os.ErrNotExist
	}
	return first, nil
}
// Prev implements source.Driver. It returns the migration version preceding
// version, or os.ErrNotExist when there is none.
func (ed *EmbedDriver) Prev(version uint) (prevVersion uint, err error) {
	before, found := ed.migrations.Prev(version)
	if !found {
		return 0, os.ErrNotExist
	}
	return before, nil
}
// Next implements source.Driver. It returns the migration version following
// version, or os.ErrNotExist when there is none.
func (ed *EmbedDriver) Next(version uint) (nextVersion uint, err error) {
	after, found := ed.migrations.Next(version)
	if !found {
		return 0, os.ErrNotExist
	}
	return after, nil
}
// ReadUp implements source.Driver. It returns the contents and identifier of
// the up migration for version. If version has no up migration, it returns
// os.ErrNotExist.
func (ed *EmbedDriver) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) {
	if m, ok := ed.migrations.Up(version); ok {
		return ed.readMigration(m)
	}
	return nil, "", os.ErrNotExist
}

// ReadDown implements source.Driver. It returns the contents and identifier
// of the down migration for version. If version has no down migration, it
// returns os.ErrNotExist.
func (ed *EmbedDriver) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) {
	if m, ok := ed.migrations.Down(version); ok {
		return ed.readMigration(m)
	}
	return nil, "", os.ErrNotExist
}

// readMigration returns the embedded contents of m together with its
// identifier.
//
// Open stored m.Raw as the embedded path with basePath (and at most one
// following slash) stripped. The original key is therefore basePath+"/"+Raw
// or basePath+Raw, and both are tried before filepath.Join. Join cleans its
// result ("./migrations" becomes "migrations"), so on its own it can name a
// key that does not exist. It is kept only as a last resort for compatibility.
//
// A missing file is reported with an error that deliberately does NOT wrap
// os.ErrNotExist. The previous code silently returned an empty reader here,
// which would run the migration as empty SQL. NOTE(review): golang-migrate
// appears to treat os.ErrNotExist from ReadUp/ReadDown as "no migration in
// this direction"; confirm against the vendored version.
func (ed *EmbedDriver) readMigration(m *source.Migration) (io.ReadCloser, string, error) {
	candidates := []string{
		ed.basePath + "/" + m.Raw,
		ed.basePath + m.Raw,
		filepath.Join(ed.basePath, m.Raw),
	}
	for _, p := range candidates {
		// With the generator's template, EmbeddedFile returns nil only for
		// unknown paths. An embedded empty file decodes to a non-nil,
		// zero-length slice.
		if data := EmbeddedFile(p); data != nil {
			return ioutil.NopCloser(bytes.NewReader(data)), m.Identifier, nil
		}
	}
	return nil, "", fmt.Errorf("embed: no embedded file for migration %q under %q", m.Raw, ed.basePath)
}
package main
import (
"bytes"
"encoding/base64"
"fmt"
"go/format"
"io/ioutil"
"log"
"os"
"path/filepath"
"text/template"
"github.com/urfave/cli"
)
// pathToVar derives the name of the generated constant that holds the
// contents of path. The name is "file" followed by the hex encoding of the
// path's bytes, so it is always a valid Go identifier and distinct for
// distinct paths.
func pathToVar(path string) string {
	return fmt.Sprintf("file%x", []byte(path))
}

// encodeFile encodes data as unpadded standard base64. This matches
// base64.RawStdEncoding, which the generated EmbeddedFile decodes with.
func encodeFile(data []byte) string {
	return base64.RawStdEncoding.EncodeToString(data)
}

// fileTemplate renders the generated Go source. It expects a value with a
// Package string (the package clause) and a Files map[string][]byte (path to
// contents). text/template visits map keys in sorted order, so the output is
// deterministic.
//
// Paths are emitted with printf "%q", so any path becomes a valid Go string
// literal with the same value. This covers Windows paths containing
// backslashes and names containing quotes. Splicing the raw path into
// "..." produced uncompilable output for such paths. The output starts with
// the standard "Code generated ... DO NOT EDIT." marker recognized by Go
// tooling.
var fileTemplate = template.Must(template.New("").Funcs(template.FuncMap{
	"pathToVar": pathToVar,
	"encode":    encodeFile,
}).Parse(`// Code generated by embed. DO NOT EDIT.

package {{ .Package }}

import (
	"encoding/base64"
	"sort"
)

const (
{{- range $path, $data := .Files }}
	{{ pathToVar $path }} = "{{ encode $data }}"
{{- end }}
)

var embedMap = map[string]string{
{{- range $path, $_ := .Files }}
	{{ printf "%q" $path }}: {{ pathToVar $path }},
{{- end }}
}

// EmbeddedFiles returns an alphabetically sorted list of the embedded files.
func EmbeddedFiles() []string {
	var fs []string
	for f := range embedMap {
		fs = append(fs, f)
	}
	sort.Strings(fs)
	return fs
}

// EmbeddedFile returns the content of the file embedded as path.
// The function will panic if the content is not properly encoded.
func EmbeddedFile(path string) []byte {
	e, ok := embedMap[path]
	if !ok {
		return nil
	}
	d, err := base64.RawStdEncoding.DecodeString(e)
	if err != nil {
		panic(err)
	}
	return d
}

// EmbeddedFileString is a convenience function
// that works like EmbeddedFile but returns a string
// instead of a byte slice.
func EmbeddedFileString(path string) string {
	return string(EmbeddedFile(path))
}
`))
// readFile returns the entire contents of the file at path. An error from
// opening or reading the file is returned unchanged, so callers can still
// test it with os.IsNotExist. It delegates to the standard library instead
// of hand-rolling Open + ReadAll.
func readFile(path string) (data []byte, err error) {
	return ioutil.ReadFile(path)
}
// embed is the CLI action. It reads every path given via --include
// (directories are walked recursively) and renders the contents into Go
// source for the package named by --package using fileTemplate. It then
// gofmt-formats the result and writes it to --destination with mode 0600.
func embed(c *cli.Context) error {
	files, err := collectIncludes(c.StringSlice("include"))
	if err != nil {
		return err
	}
	templateData := struct {
		Package string
		Files   map[string][]byte
	}{
		Package: c.String("package"),
		Files:   files,
	}
	var buf bytes.Buffer
	if err := fileTemplate.Execute(&buf, templateData); err != nil {
		return fmt.Errorf("fileTemplate.Execute: %w", err)
	}
	generated, err := format.Source(buf.Bytes())
	if err != nil {
		return fmt.Errorf("format.Source: %w", err)
	}
	// ioutil.WriteFile opens with O_WRONLY|O_CREATE|O_TRUNC like the previous
	// code. Unlike a deferred Close, it also reports the error from closing
	// the file, which is where delayed write failures can surface.
	dest := c.String("destination")
	if err := ioutil.WriteFile(dest, generated, 0600); err != nil {
		return fmt.Errorf("writing %q: %w", dest, err)
	}
	return nil
}

// collectIncludes reads every include path into a map keyed by path. Regular
// paths are stored under the path as given. Directories are walked with
// filepath.Walk, and each non-directory entry is stored under the path Walk
// reports.
func collectIncludes(includePaths []string) (map[string][]byte, error) {
	files := make(map[string][]byte)
	for _, includePath := range includePaths {
		info, err := os.Stat(includePath)
		if err != nil {
			return nil, fmt.Errorf("stat: %w", err)
		}
		if !info.IsDir() {
			data, err := readFile(includePath)
			if err != nil {
				return nil, fmt.Errorf("readFile: %w", err)
			}
			files[includePath] = data
			continue
		}
		walkFn := func(path string, info os.FileInfo, err error) error {
			if err != nil {
				return err
			}
			if info.IsDir() {
				return nil
			}
			data, err := readFile(path)
			if err != nil {
				return fmt.Errorf("readFile: %w", err)
			}
			files[path] = data
			return nil
		}
		if err := filepath.Walk(includePath, walkFn); err != nil {
			return nil, fmt.Errorf("filepath.Walk: %w", err)
		}
	}
	return files, nil
}
// main defines the "embed" command-line application, runs the embed action
// with the process arguments, and exits through log.Fatal on failure.
func main() {
	flags := []cli.Flag{
		&cli.StringFlag{
			Name:  "package",
			Usage: "name of the package the generated Go file is associated to",
			Value: "main",
		},
		&cli.StringFlag{
			Name:  "destination",
			Usage: "where to store the generated Go file",
			Value: "embeds.go",
		},
		&cli.StringSliceFlag{
			Name:     "include",
			Usage:    "paths to embed, directories are stored recursively (can be used multiple times)",
			Required: true,
		},
	}
	app := cli.App{
		Name:   "embed",
		Flags:  flags,
		Action: embed,
	}
	if err := app.Run(os.Args); err != nil {
		log.Fatal(err)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment