Skip to content

Instantly share code, notes, and snippets.

@hucsmn
Last active January 29, 2018 13:58
Show Gist options
  • Save hucsmn/c07330c5761e7fabdf74963bee85ecf2 to your computer and use it in GitHub Desktop.
Save hucsmn/c07330c5761e7fabdf74963bee85ecf2 to your computer and use it in GitHub Desktop.
go generate 工具,嵌入文本文件内容
// embed.go: go generate tool for embedding text files
//TODO: file preprocessors (using os/exec), concurrent workers
package main
import (
"bytes"
"flag"
"fmt"
"go/format"
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
"regexp"
"strconv"
"strings"
"text/template"
"unicode"
"unicode/utf8"
)
// srcTmplOutput is the text/template source of the generated file. It is
// executed with an OutputData value and the helper functions toString and
// toBytes, which render an item's Value as a Go string literal or as a
// []byte composite literal. The rendered text is passed through gofmt
// afterwards, so the template does not care about indentation or blank lines.
//
// The first line follows the standard generated-code convention: a comment
// matching `^// Code generated .* DO NOT EDIT\.$`, so linters and editors
// recognise the file as generated. It is followed by a blank line so that it
// is not attached to the package clause as package documentation.
const srcTmplOutput = `// Code generated by {{.Arguments}}; DO NOT EDIT.

package {{.Package}}
{{if eq .Type "string"}}
const ({{range .Items}}
{{if ne .Comment ""}}// {{.Comment}}
{{end}}{{.Name}} = {{.Value | toString}}
{{end}}
)
{{else if eq .Type "bytes"}}
var ({{range .Items}}
{{if ne .Comment ""}}// {{.Comment}}
{{end}}{{.Name}} = {{.Value | toBytes}}
{{end}}
)
{{else}}
// invalid target type: {{.Type}}
{{end}}
`
// Command-line options. The custom usage printer only shows each flag's
// usage string, so the strings spell out their own defaults.
var (
	// flagOut is the output file; "-" writes to standard output.
	flagOut = flag.String("o", "out.embed.go", "the output file name (out.embed.go by default)")
	// flagPackage overrides the package clause of the generated file.
	flagPackage = flag.String("p", "", "package name (use output directory name by default)")
	// flagType selects the generated declarations ("string" or "bytes").
	flagType = flag.String("t", "string", "target type: string (const string, default), bytes (var []byte)")
	// flagFormat is a text/template evaluated once per input to name its item.
	flagFormat = flag.String("f", "{{.Name | private}}", "naming format, default is {{.Name | private}}.")
	// flagQuiet suppresses warnings.
	flagQuiet = flag.Bool("q", false, "be quiet")
	// flagHelp prints the usage text and exits.
	flagHelp = flag.Bool("h", false, "show this help")
)
// Package-level state: the regular expressions are fixed at start-up, the
// remaining variables are filled in by init from the command line and read
// by main.
var (
	// rxIdentifier matches a Go identifier as defined by the language spec:
	// a letter or underscore, followed by letters, decimal digits (Nd) or
	// underscores. Go keywords are NOT excluded.
	rxIdentifier = regexp.MustCompile(`^[\p{L}_][\p{L}\p{Nd}_]*$`)
	// rxWord matches a run of letters and decimal digits. autoNaming uses it
	// to split a name into words; restricting digits to Nd keeps every word
	// made of characters that are legal inside a Go identifier.
	rxWord = regexp.MustCompile(`[\p{L}\p{Nd}]+`)

	outName       string                           // output file name, "-" for stdout
	outData       OutputData                       // value rendered by tmplOutput
	nameFormatter func(NamingData) (string, error) // evaluates the -f naming template
	tmplOutput    *template.Template               // compiled srcTmplOutput
)
// OutputData is the root value the output template (srcTmplOutput) is
// executed with.
type OutputData struct {
	// Arguments is the command line recorded in the generated file's header.
	Arguments string
	// Package is the package clause of the generated file.
	Package string
	// Type selects the generated declarations: "string" emits a const block
	// of string literals, "bytes" emits a var block of []byte literals.
	Type string
	// Items holds one entry per input file, in command-line order.
	Items []OutputDataItem
}

// OutputDataItem is one embedded input file.
type OutputDataItem struct {
	// Comment collects warnings about this item (naming or read failures);
	// when non-empty it is emitted as a // comment above the declaration.
	Comment string
	// Name is the identifier the item is declared as.
	Name string
	// Value is the file content; it stays nil when the file could not be read.
	Value []byte
}

// NamingData is the value the -f naming template is executed with, once per
// input file.
type NamingData struct {
	// Number is the zero-based position of the input among the arguments.
	Number int
	// Name is the base file name with leading dots removed, up to the first dot.
	Name string
	// Ext is everything after that first dot ("tar.gz" for "a.tar.gz").
	Ext string
	// FileName is the base name of the input path.
	FileName string
	// DirName is the name of the directory containing the input file.
	DirName string
	// Package is the package clause of the generated file.
	Package string
}
func init() {
outData.Arguments = strings.Join(os.Args, " ")
flag.Usage = func() {
fmt.Fprintln(os.Stderr, "usage: go run embed.go [options] input ...")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "options:")
flag.VisitAll(func(flg *flag.Flag) {
fmt.Fprintf(os.Stderr, "\t-%s\t%s\n", flg.Name, flg.Usage)
})
}
flag.Parse()
if *flagHelp || flag.NArg() <= 0 {
flag.Usage()
os.Exit(2)
}
outName = *flagOut
pkg := *flagPackage
autopkg := false
if pkg == "" {
outdir := path.Dir(path.Clean(filepath.ToSlash(*flagOut)))
if outdir == "" || outdir == "." {
outdir, _ = filepath.Abs(".")
}
pkg = filepath.Base(outdir)
autopkg = true
}
if !(rxIdentifier.MatchString(pkg) && pkg != "_") {
if autopkg {
Abort("failed to choose package name by output directory: %s\n", pkg)
} else {
Abort("invalid package name: %s\n", pkg)
}
}
outData.Package = pkg
typ := *flagType
typ = strings.ToLower(typ)
switch typ {
case "string", "bytes":
default:
Abort("invalid target type: %s\n", typ)
}
outData.Type = typ
tmpl := template.New("naming")
tmpl.Funcs(map[string]interface{}{
"public": func(s string) (string, error) { return autoNaming(s, true) },
"private": func(s string) (string, error) { return autoNaming(s, false) },
})
tmplNaming, err := tmpl.Parse(*flagFormat)
if err != nil {
Abort("bad name format: %s\n", err.Error())
}
nameFormatter = func(data NamingData) (string, error) {
b, err := render(tmplNaming, data)
return string(b), err
}
tmpl = template.New("output")
tmpl.Funcs(map[string]interface{}{
"toString": toString,
"toBytes": toBytes,
})
tmplOutput, err = tmpl.Parse(srcTmplOutput)
if err != nil {
Abort("internal output template: %s\n", err.Error())
}
}
func main() {
var outFile io.WriteCloser
if outName != "-" {
fi, err := os.Stat(outName)
if !os.IsNotExist(err) {
if fi.IsDir() {
Abort("output file is a directory: %s\n", outName)
}
Warn("output file already exists: %s\n", outName)
}
outFile, err = os.Create(outName)
if err != nil {
Abort("create output file: %s\n", outName)
}
} else {
outFile = os.Stdout
}
defer outFile.Close()
for i, inName := range flag.Args() {
absName, _ := filepath.Abs(inName)
dirName := filepath.Base(filepath.Dir(absName))
baseName := path.Base(inName)
baseFields := strings.SplitN(strings.TrimLeft(baseName, "."), ".", 2)
name := baseFields[0]
ext := ""
if len(baseFields) == 2 {
ext = baseFields[1]
}
item := OutputDataItem{}
itemName, err := nameFormatter(NamingData{
Number: i,
Name: name,
Ext: ext,
FileName: baseName,
DirName: dirName,
Package: outData.Package,
})
if err != nil {
itemName = fmt.Sprintf("item%d", i)
msg := fmt.Sprintf("failed naming item %q, use item%d instead (%s). ", inName, i, err)
item.Comment += msg
Warn("%s\n", msg)
}
item.Name = itemName
itemValue, err := ioutil.ReadFile(inName)
if err != nil {
msg := fmt.Sprintf("failed to read %q (%s). ", inName, err)
item.Comment += msg
Warn("%s\n", msg)
} else {
item.Value = itemValue
}
outData.Items = append(outData.Items, item)
}
raw, err := render(tmplOutput, outData)
if err != nil {
Abort("render: %s\n", err)
}
outFmt, err := format.Source(raw)
if err != nil {
Abort("gofmt: %s\n", err)
}
_, err = io.Copy(outFile, bytes.NewReader(outFmt))
if err != nil {
Abort("write: %s\n", err)
}
}
// toString renders bs as a double-quoted Go string literal. Non-printable
// characters and invalid UTF-8 bytes are escaped, so the literal reproduces
// the original bytes exactly.
func toString(bs []byte) string {
	text := string(bs)
	literal := strconv.Quote(text)
	return literal
}
// toBytes renders bs as Go syntax for a byte slice, using fmt's %#v
// formatting: []byte{0x61, 0x62} for data, []byte(nil) for a nil slice.
func toBytes(bs []byte) string {
	const goSyntax = "%#v"
	literal := fmt.Sprintf(goSyntax, bs)
	return literal
}
// render executes tmpl against data and returns the produced text. On an
// execution error any partial output is discarded and nil is returned along
// with the error.
func render(tmpl *template.Template, data interface{}) ([]byte, error) {
	var out bytes.Buffer
	if err := tmpl.Execute(&out, data); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// autoNaming turns an arbitrary name (usually a file name without extension)
// into a camel-case Go identifier.
//
// Underscores and every character not matched by rxWord separate words. The
// first word is upper-cased only when public is true; every later word always
// starts upper-case. When the first word does not begin with a letter, the
// name is prefixed with "x" ("X" when public).
//
// Two further adjustments keep the result declarable:
//   - a public name that still does not start with an upper-case letter (for
//     example in caseless scripts such as CJK) is prefixed with "X", since Go
//     only exports names starting with an upper-case letter;
//   - a private name equal to a Go keyword or predeclared identifier (e.g.
//     "type", "string", "len") is prefixed with "x". Declaring such a name at
//     package level is either a syntax error or shadows the builtin for the
//     whole package.
//
// An error is returned when name contains no word characters at all.
func autoNaming(name string, public bool) (string, error) {
	s := strings.Replace(name, "_", " ", -1)
	words := rxWord.FindAllString(s, -1)
	if len(words) == 0 {
		return "", fmt.Errorf("failed to auto naming: %s", name)
	}
	if r, _ := utf8.DecodeRuneInString(words[0]); !unicode.IsLetter(r) {
		words = append([]string{"x"}, words...)
	}
	for i := range words {
		if public || i > 0 {
			words[i] = firstRuneToUpper(words[i])
		}
	}
	ident := strings.Join(words, "")

	if public {
		if r, _ := utf8.DecodeRuneInString(ident); !unicode.IsUpper(r) {
			ident = "X" + ident
		}
	} else if isReservedWord(ident) {
		ident = "x" + firstRuneToUpper(ident)
	}
	return ident, nil
}

// isReservedWord reports whether s is a Go keyword or a predeclared
// identifier (type, constant, zero value or builtin function).
func isReservedWord(s string) bool {
	switch s {
	case "break", "case", "chan", "const", "continue", "default", "defer",
		"else", "fallthrough", "for", "func", "go", "goto", "if", "import",
		"interface", "map", "package", "range", "return", "select", "struct",
		"switch", "type", "var",
		"any", "bool", "byte", "comparable", "complex64", "complex128",
		"error", "float32", "float64", "int", "int8", "int16", "int32",
		"int64", "rune", "string", "uint", "uint8", "uint16", "uint32",
		"uint64", "uintptr",
		"true", "false", "iota", "nil",
		"append", "cap", "clear", "close", "complex", "copy", "delete",
		"imag", "len", "make", "max", "min", "new", "panic", "print",
		"println", "real", "recover":
		return true
	}
	return false
}
// firstRuneToUpper returns s with its first rune mapped through
// unicode.ToUpper; the rest of s is copied unchanged. An empty string, or a
// string whose first byte does not start a valid UTF-8 sequence, is returned
// as is rather than having U+FFFD substituted for its first character.
func firstRuneToUpper(s string) string {
	r, n := utf8.DecodeRuneInString(s)
	if r == utf8.RuneError && n <= 1 {
		return s
	}
	return string(unicode.ToUpper(r)) + s[n:]
}
// Abort writes "error: " followed by the printf-style message to standard
// error and terminates the process with exit status 1 via os.Exit, so
// deferred calls do not run. Callers supply the trailing newline themselves.
func Abort(format string, v ...interface{}) {
	msg := fmt.Sprintf(format, v...)
	fmt.Fprint(os.Stderr, "error: "+msg)
	os.Exit(1)
}
// Warn writes "warning: " followed by the printf-style message to standard
// error, unless -q was given. Callers supply the trailing newline themselves.
func Warn(format string, v ...interface{}) {
	if *flagQuiet {
		return
	}
	msg := fmt.Sprintf(format, v...)
	fmt.Fprint(os.Stderr, "warning: "+msg)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment