Skip to content

Instantly share code, notes, and snippets.

@skipor
Last active March 20, 2019 14:50
Show Gist options
  • Save skipor/7ca81670261b949abd001613ed6898ac to your computer and use it in GitHub Desktop.
Save skipor/7ca81670261b949abd001613ed6898ac to your computer and use it in GitHub Desktop.
CodeWriter helper for generating Go code.
// Copyright (c) 2019 Yandex LLC.
// Author: Dmitry Novikov <novikoff@yandex-team.ru>
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
package gen
import (
	"bytes"
	"fmt"
	"io/ioutil"
	"path"
	"path/filepath"
	"sort"
	"strings"

	"golang.org/x/tools/imports"
)
// ImportInfo describes one import registered through CodeWriter.AddImport.
type ImportInfo struct {
	alias string     // explicit import name; empty when none was requested
	path  string     // import path exactly as passed to AddImport
	t     importType // import group Format writes this entry into
}

// Path returns the import path.
func (i *ImportInfo) Path() string { return i.path }

// Alias returns the explicit import name, or "" when none was set.
func (i *ImportInfo) Alias() string { return i.alias }

// PackageName returns the identifier generated code should use to refer to
// the imported package: the alias when one was set, otherwise the last
// element of the import path.
func (i *ImportInfo) PackageName() string {
	name := i.Alias()
	if name == "" {
		name = path.Base(i.Path())
	}
	return name
}
// importType identifies the import group an import belongs to. Format emits
// the groups in the order standard, external, internal.
type importType int

const (
	internal importType = 0 // imports from the repository's own internal prefix
	external importType = 1 // third-party imports
	standard importType = 2 // standard-library imports
)
// WriterOption customizes the ImportInfo built by CodeWriter.AddImport.
type WriterOption func(*ImportInfo)

// Alias returns an option that imports the package under the name a.
// AddImport ignores options when the path has already been registered.
func Alias(a string) WriterOption {
	setAlias := func(info *ImportInfo) {
		info.alias = a
	}
	return setAlias
}
// SavePoint records the length, in bytes, of a CodeWriter's generated body
// at a given moment. Pass it to Rollback to discard what was written later.
type SavePoint int

// CodeWriter is a helper for code generation. It accumulates a Go file body
// and, in Format, prepends a header comment, the package clause and a grouped
// import block. No constructor is needed: the imports map is allocated lazily
// by AddImport. Set a package name before calling Format.
type CodeWriter struct {
	noLint      bool   // emit //nolint and //revive:disable markers
	packageName string // name written in the package clause

	// internalImportPrefix is your repository package path prefix.
	// Imports with such a prefix are put in a separate block.
	internalImportPrefix string

	comment string                 // header comment text
	imports map[string]*ImportInfo // registered imports keyed by import path
	buf     bytes.Buffer           // generated body
}
// SavePoint captures the current length of the generated body so that a later
// Rollback can discard anything written after this point.
func (w *CodeWriter) SavePoint() SavePoint {
	n := w.buf.Len()
	return SavePoint(n)
}

// Rollback truncates the generated body back to p. p must not exceed the
// current body length (e.g. a save point taken before an earlier rollback to
// a smaller one); bytes.Buffer.Truncate panics in that case.
func (w *CodeWriter) Rollback(p SavePoint) {
	n := int(p)
	w.buf.Truncate(n)
}
// SetPackageName sets the name Format writes in the package clause.
func (w *CodeWriter) SetPackageName(name string) { w.packageName = name }

// SetComment sets the header comment Format writes at the top of the file.
func (w *CodeWriter) SetComment(x string) { w.comment = x }

// SetInternalImportPrefix sets the import-path prefix of your own repository;
// matching imports are written in a separate group. The group is decided when
// an import is added, so call this before AddImport.
func (w *CodeWriter) SetInternalImportPrefix(prefix string) { w.internalImportPrefix = prefix }

// SetNoLint makes Format emit //nolint and //revive:disable markers.
func (w *CodeWriter) SetNoLint() { w.noLint = true }
// GetImport returns the ImportInfo registered for path, or nil if path has
// not been added.
func (w *CodeWriter) GetImport(path string) *ImportInfo {
	info, ok := w.imports[path]
	if !ok {
		return nil
	}
	return info
}
// AddImport registers an import of path and returns its ImportInfo.
//
// A fresh ImportInfo is built and every option in opts is applied to it. If
// path was already registered, the previously stored ImportInfo is returned
// unchanged and the fresh one (including its options) is discarded.
func (w *CodeWriter) AddImport(path string, opts ...WriterOption) *ImportInfo {
	info := &ImportInfo{path: path}
	info.t = w.detectImportType(path)
	for i := range opts {
		opts[i](info)
	}
	if w.imports == nil {
		w.imports = make(map[string]*ImportInfo)
	}
	if prev := w.imports[path]; prev != nil {
		return prev
	}
	w.imports[path] = info
	return info
}
// NL appends a single newline to the generated body and returns w for
// chaining.
func (w *CodeWriter) NL() *CodeWriter {
	// bytes.Buffer.WriteByte always returns a nil error.
	w.buf.WriteByte('\n')
	return w
}
// Quoted writes x wrapped in double quotes, followed by a newline, and returns
// w for chaining. It is used for writing import paths. x is inserted verbatim
// (no escaping is applied).
func (w *CodeWriter) Quoted(x string) *CodeWriter {
	// Pass x as an argument instead of concatenating it into the format
	// string, so a '%' in x is not interpreted as a formatting verb.
	return w.Line(`"%s"`, x)
}
// Line formats args according to format (fmt.Fprintf semantics), appends the
// result and a trailing newline to the generated body, and returns w for
// chaining.
func (w *CodeWriter) Line(format string, args ...interface{}) *CodeWriter {
	w.Write(format, args...)
	return w.NL()
}
// Write formats args according to format (fmt.Fprintf semantics) and appends
// the result to the generated body without a trailing newline. It returns w
// for chaining.
func (w *CodeWriter) Write(format string, args ...interface{}) *CodeWriter {
	// Writes to a bytes.Buffer never return an error, so it is safe to drop.
	_, _ = fmt.Fprintf(&w.buf, format, args...)
	return w
}
// Format returns the generated file, formatted with goimports
// (golang.org/x/tools/imports.Process).
//
// The file is assembled from, in order:
//   - the header comment set via SetComment, with every line prefixed by "// ";
//   - a "//nolint" marker when SetNoLint was called;
//   - the package clause;
//   - an import block with the standard, external and internal groups
//     separated by blank lines, each group sorted by path;
//   - a "//revive:disable" marker when SetNoLint was called;
//   - the body accumulated by Write/Line/NL.
//
// If goimports fails, the unformatted source is returned together with the
// error so the caller can still inspect or save it.
func (w *CodeWriter) Format() ([]byte, error) {
	output := &bytes.Buffer{}
	// Prefix every line: emitting "// "+comment as-is would leave the second
	// and later lines of a multi-line comment as bare (invalid) Go code.
	for _, line := range strings.Split(strings.TrimSuffix(w.comment, "\n"), "\n") {
		fmt.Fprintln(output, "// "+line)
	}
	fmt.Fprintln(output)
	if w.noLint {
		fmt.Fprintln(output)
		fmt.Fprintln(output, "//nolint")
	}
	fmt.Fprintln(output, "package "+w.packageName)
	fmt.Fprintln(output)
	if len(w.imports) > 0 {
		// Map iteration order is random; sort the paths so the output is
		// deterministic even when goimports fails and the raw source is returned.
		paths := make([]string, 0, len(w.imports))
		for p := range w.imports {
			paths = append(paths, p)
		}
		sort.Strings(paths)

		fmt.Fprintln(output, "import (")
		write := func(t importType) {
			for _, p := range paths {
				v := w.imports[p]
				if v.t != t {
					continue
				}
				if v.alias != "" {
					fmt.Fprint(output, v.alias+" ")
				}
				fmt.Fprintln(output, `"`+v.path+`"`)
			}
			// A blank line keeps each group in its own block.
			fmt.Fprintln(output)
		}
		write(standard)
		write(external)
		write(internal)
		fmt.Fprintln(output, ")")
		fmt.Fprintln(output)
	}
	if w.noLint {
		fmt.Fprintln(output)
		fmt.Fprintln(output, "//revive:disable")
		fmt.Fprintln(output)
	}
	output.Write(w.buf.Bytes())
	result, err := imports.Process("", output.Bytes(), nil)
	if err != nil {
		return output.Bytes(), err
	}
	return result, nil
}
// WriteFile formats the generated code and writes it, with mode 0644, to the
// file whose path is filepath.Join of the given elements.
//
// The file is written even when formatting fails (it then holds whatever
// Format returned alongside the error), so the output can be inspected.
// A write error takes precedence over a formatting error.
func (w *CodeWriter) WriteFile(path ...string) error {
	filename := filepath.Join(path...)
	content, formatErr := w.Format()
	if writeErr := ioutil.WriteFile(filename, content, 0644); writeErr != nil {
		return writeErr
	}
	if formatErr == nil {
		return nil
	}
	return fmt.Errorf("error formatting file %v: %v", filename, formatErr)
}
// detectImportType classifies path into one of the import groups written by
// Format:
//   - internal: an internal prefix is configured and path starts with it;
//   - external: the first path element contains a dot (a domain name such as
//     "golang.org"), the same rule goimports uses;
//   - standard: everything else, assumed to be the standard library.
func (w *CodeWriter) detectImportType(path string) importType {
	// strings.HasPrefix(s, "") is always true. Without this guard, an unset
	// prefix would classify every import, stdlib included, as internal.
	if w.internalImportPrefix != "" && strings.HasPrefix(path, w.internalImportPrefix) {
		return internal
	}
	firstElem := path
	if i := strings.IndexByte(path, '/'); i >= 0 {
		firstElem = path[:i]
	}
	if strings.Contains(firstElem, ".") {
		return external
	}
	return standard
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment