Skip to content

Instantly share code, notes, and snippets.

@bruth
Created June 19, 2017 16:59
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bruth/b63b5c48df3007dd7aeee42de09f58a2 to your computer and use it in GitHub Desktop.
Save bruth/b63b5c48df3007dd7aeee42de09f58a2 to your computer and use it in GitHub Desktop.

NATS-RPC Generator

A go generate tool that creates a client interface, a CLI, and a Serve function for a service interface.

Usage

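//go:generate nats-rpc -type=Service -client=client.go -cli=./cmd/cli/main.go
package example

import "context"

type Req struct {
  Left int
  Right int
}

type Rep struct {
  Sum int
}

type Service interface {
  Add(context.Context, *Req) (*Rep, error)
}

The service must live in an importable (non-main) package, because the generated CLI imports it. Req and Rep are shown as plain structs for brevity. In practice they must be protobuf messages (for example, generated from a .proto file like the one below), because the generated code exchanges them as proto.Message values.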

Options

  • type - (required) Name of the service interface type. Every method must have the signature (context.Context, *<RequestType>) (*<ResponseType>, error), where the request and response types are user-defined protobuf message types.
  • client - (required) Name of the output file for the client type and the Serve function.
  • cli - (required) Name of the output file for the CLI.
  • group - Name of the NATS queue group for the serve subscription handlers. Defaults to <pkg-name>.svc.
  • prefix - Prefix prepended verbatim to all NATS subjects (include a trailing "." if you want a separator). Subjects have the form <prefix><pkg-name>.<Method>. Defaults to no prefix.
//go:generate nats-rpc -type=Service -client=client.go -cli=./cmd/cli/main.go
package example
import "context"
// Service is the example RPC service that nats-rpc generates code for (see
// the go:generate directive above). Each method takes a context and a
// pointer to a request message and returns a pointer to a response message
// or an error.
type Service interface {
	Add(ctx context.Context, req *Req) (*Rep, error)
}
syntax = "proto3";
package example;
// Req is the request message of Service.Add in the example service.
message Req {
// NOTE(review): presumably the two operands to be added — confirm.
int32 left = 1;
int32 right = 2;
}
// Rep is the response message returned by Service.Add in the example service.
message Rep {
// NOTE(review): presumably left + right from the request — confirm against
// the service implementation.
int32 sum = 1;
}
package main
import (
"fmt"
"go/ast"
"go/build"
"go/importer"
"go/parser"
"go/token"
"go/types"
"path/filepath"
"strings"
)
// defaultImporter returns the importer that (*Package).check hands to the
// type checker to resolve the package's imports. It is the standard
// library's default importer for the compiler that built this binary.
func defaultImporter() types.Importer {
	var imp types.Importer = importer.Default()
	return imp
}
// prefixDirectory joins directory onto the front of every file name in
// names. When directory is exactly ".", names is returned unchanged (the
// same slice); otherwise a new slice of the joined paths is returned.
func prefixDirectory(directory string, names []string) []string {
	if directory == "." {
		return names
	}
	joined := make([]string, 0, len(names))
	for _, n := range names {
		joined = append(joined, filepath.Join(directory, n))
	}
	return joined
}
// File holds a single parsed Go source file together with the Package it was
// parsed into.
type File struct {
// pkg points back to the owning Package; parsePackage sets it.
pkg *Package
// Parsed AST.
file *ast.File
}
// Package is a parsed and type-checked Go package, as built by parsePackage.
type Package struct {
// dir is the directory the package's files were parsed from.
dir string
// name is the package name taken from the package clause of the first
// parsed file.
name string
// files holds every file parsed by parsePackage.
files []*File
// defs maps each defining identifier in the AST to the object it defines,
// as recorded by the type checker in check. Some entries may map to a nil
// object; callers such as main skip those.
defs map[*ast.Ident]types.Object
// typesPkg is the package returned by types.Config.Check; set only when
// type checking succeeded.
typesPkg *types.Package
}
// check runs the go/types checker over astFiles using defaultImporter.
// p.defs is (re)initialized and filled with the checker's definitions
// before the result is known. p.typesPkg is set only on success. The first
// type error is returned unchanged; the package must type-check cleanly for
// generation to proceed.
func (p *Package) check(fs *token.FileSet, astFiles []*ast.File) error {
	defs := make(map[*ast.Ident]types.Object)
	p.defs = defs

	conf := types.Config{
		Importer:    defaultImporter(),
		FakeImportC: true,
	}
	checked, err := conf.Check(p.dir, fs, astFiles, &types.Info{Defs: defs})
	if err != nil {
		return err
	}
	p.typesPkg = checked
	return nil
}
// ParsePackageDir parses and type-checks the Go package in directory d.
// go/build selects the non-test .go files that apply to the current build
// context, and those files are handed to parsePackage.
func ParsePackageDir(d string) (*Package, error) {
	bp, err := build.Default.ImportDir(d, 0)
	if err != nil {
		return nil, fmt.Errorf("cannot process directory %s: %s", d, err)
	}
	// Copy GoFiles so prefixDirectory never aliases go/build's slice.
	files := prefixDirectory(d, append([]string(nil), bp.GoFiles...))
	return parsePackage(d, files, nil)
}
// parsePackage parses the named files into a single Package and
// type-checks it. Names without a ".go" suffix are skipped. When text is
// non-nil it is passed to parser.ParseFile as the source of every file
// instead of reading the files from disk (intended for tests).
//
// Errors are returned, not fatal: a parse error, no .go files at all, or a
// type-check failure all yield a nil *Package and the error.
func parsePackage(directory string, names []string, text interface{}) (*Package, error) {
	fset := token.NewFileSet()
	pkg := new(Package)
	var files []*ast.File

	for _, name := range names {
		if strings.HasSuffix(name, ".go") {
			f, err := parser.ParseFile(fset, name, text, 0)
			if err != nil {
				return nil, err
			}
			files = append(files, f)
			pkg.files = append(pkg.files, &File{pkg: pkg, file: f})
		}
	}

	if len(files) == 0 {
		return nil, fmt.Errorf("%s: no buildable Go files", directory)
	}

	pkg.dir = directory
	pkg.name = files[0].Name.Name

	if err := pkg.check(fset, files); err != nil {
		return nil, err
	}
	return pkg, nil
}
package main
import (
"bytes"
"flag"
"fmt"
"go/format"
"go/types"
"io/ioutil"
"log"
"text/template"
)
// init configures the standard logger so every diagnostic is printed as a
// bare message (no timestamp) prefixed with the tool name.
func init() {
	log.SetPrefix("nats-rpc: ")
	log.SetFlags(0)
}
func main() {
var (
typeName string
fileName string
cliFileName string
serviceGroup string
subjectPrefix string
)
flag.StringVar(&typeName, "type", "", "Type name.")
flag.StringVar(&fileName, "client", "", "Output file name client interface.")
flag.StringVar(&cliFileName, "cli", "", "Output file name for CLI.")
flag.StringVar(&serviceGroup, "group", "", "Name of the NATS queue group.")
flag.StringVar(&subjectPrefix, "prefix", "", "Prefix to all subjects.")
flag.Parse()
if typeName == "" {
log.Fatal("type required")
}
if fileName == "" {
log.Fatal("file name required")
}
if cliFileName == "" {
log.Fatal("cli file name required")
}
args := flag.Args()
// Default to current directory.
if len(args) == 0 {
args = []string{"."}
}
pkg, err := ParsePackageDir(args[0])
if err != nil {
log.Fatal(err)
}
var (
ok bool
obj types.Object
inf *types.Interface
)
for _, obj = range pkg.defs {
if obj == nil {
continue
}
// Ignore objects that don't have the target name.
if obj.Name() != typeName {
continue
}
// Looking for an interface type..
inf, ok = obj.Type().Underlying().(*types.Interface)
if !ok {
continue
}
break
}
meta := reflectInterface(inf)
meta.Name = typeName
meta.Pkg = obj.Pkg().Name()
if serviceGroup == "" {
serviceGroup = fmt.Sprintf("%#v", fmt.Sprintf("%s.svc", meta.Pkg))
}
for _, m := range meta.Methods {
m.Pkg = meta.Pkg
m.Topic = fmt.Sprintf("%#v", fmt.Sprintf("%s%s.%s", subjectPrefix, meta.Pkg, m.Name))
m.ServiceGroup = serviceGroup
}
// Compile and generate files.
var buf bytes.Buffer
t := template.Must(template.New("client").Parse(fileTmpl))
if err := t.Execute(&buf, meta); err != nil {
log.Fatal(err)
}
// Format the output.
src, err := format.Source(buf.Bytes())
if err != nil {
log.Fatal(err)
}
if err = ioutil.WriteFile(fileName, src, 0644); err != nil {
log.Fatalf("writing output: %s", err)
}
// Reuse buffer.
buf.Reset()
t = template.Must(template.New("cli").Parse(cliTmpl))
if err := t.Execute(&buf, meta); err != nil {
log.Fatal(err)
}
// Format the output.
src, err = format.Source(buf.Bytes())
if err != nil {
log.Fatal(err)
}
if err = ioutil.WriteFile(cliFileName, src, 0644); err != nil {
log.Fatalf("writing output: %s", err)
}
}
package main
import (
"go/types"
"log"
)
// Interface describes the service interface; main passes it as the data for
// both code templates.
type Interface struct {
// Pkg is the name (not the import path) of the package declaring the
// interface.
Pkg string
// Name is the interface type name given by the -type flag.
Name string
// Methods describes each method, in the order reported by
// types.Interface.Method.
Methods []*Method
}
// Method describes one service method for the code templates.
type Method struct {
// Pkg is the name of the package declaring the service interface.
Pkg string
// Name is the method name.
Name string
// Topic is the NATS subject for the method. main stores it as a quoted Go
// string literal because the templates splice it in as an expression.
Topic string
// Request describes the method's second parameter.
Request *Var
// Response describes the method's first result.
Response *Var
// ServiceGroup is the NATS queue group. The templates splice it in
// verbatim as a Go expression, so it needs to be a quoted string literal.
ServiceGroup string
// NOTE(review): ins and outs are not populated anywhere in the visible code.
ins []*Var
outs []*Var
}
// Var describes the type of one method parameter or result, as filled in by
// reflectVar.
type Var struct {
// Pkg is the name of the package declaring the named type; empty when the
// type has no package or reflectVar did not record one.
Pkg string
// Type is the bare name of the named type, without package qualifier or '*'.
Type string
// Ptr reports whether the declared type was a pointer.
Ptr bool
}
// reflectInterface converts every method of iface, in the order reported by
// types.Interface.Method, into a *Method via reflectMethod. Name and Pkg of
// the result are left empty for the caller to fill in.
func reflectInterface(iface *types.Interface) *Interface {
	methods := make([]*Method, iface.NumMethods())
	for i := range methods {
		methods[i] = reflectMethod(iface.Method(i))
	}
	return &Interface{Methods: methods}
}
// reflectMethod describes one service method for the templates. The method
// must have the shape
//
//	Name(context.Context, *Request) (*Response, error)
//
// because the generated client declares exactly that signature, and Serve
// decodes into a Request value and returns the method's results directly.
// Any other shape is reported with log.Fatalf, naming the offending method,
// instead of producing generated code that fails to compile.
func reflectMethod(m *types.Func) *Method {
	name := m.Name()
	sig := m.Type().(*types.Signature)
	params := sig.Params()
	results := sig.Results()

	if params.Len() != 2 {
		log.Fatalf("%s: expected 2 params, got %d", name, params.Len())
	}
	if results.Len() != 2 {
		log.Fatalf("%s: expected 2 results, got %d", name, results.Len())
	}
	if t := params.At(0).Type(); !isContextType(t) {
		log.Fatalf("%s: first param must be context.Context, got %s", name, t)
	}
	if t := results.At(1).Type(); !types.Identical(t, types.Universe.Lookup("error").Type()) {
		log.Fatalf("%s: second result must be error, got %s", name, t)
	}
	if t := params.At(1).Type(); !isPointerToNamed(t) {
		log.Fatalf("%s: request must be a pointer to a named type, got %s", name, t)
	}
	if t := results.At(0).Type(); !isPointerToNamed(t) {
		log.Fatalf("%s: response must be a pointer to a named type, got %s", name, t)
	}

	return &Method{
		Name:     name,
		Request:  reflectVar(params.At(1)),
		Response: reflectVar(results.At(0)),
	}
}

// isContextType reports whether t is the named type context.Context.
func isContextType(t types.Type) bool {
	named, ok := t.(*types.Named)
	if !ok {
		return false
	}
	obj := named.Obj()
	return obj.Pkg() != nil && obj.Pkg().Path() == "context" && obj.Name() == "Context"
}

// isPointerToNamed reports whether t is *T for some named type T.
func isPointerToNamed(t types.Type) bool {
	ptr, ok := t.(*types.Pointer)
	if !ok {
		return false
	}
	_, ok = ptr.Elem().(*types.Named)
	return ok
}
// reflectVar describes the type of v for the templates. A single level of
// pointer is unwrapped and recorded in Ptr. If the remaining type is a named
// type, Type holds its name and Pkg its package name; Pkg is empty for
// predeclared types such as error. Any other type leaves Type and Pkg empty.
// This includes pointers to unnamed types such as *int, which are handled
// without panicking.
func reflectVar(v *types.Var) *Var {
	var x Var
	t := v.Type()
	if ptr, ok := t.(*types.Pointer); ok {
		x.Ptr = true
		t = ptr.Elem()
	}

	named, ok := t.(*types.Named)
	if !ok {
		return &x
	}

	obj := named.Obj()
	x.Type = obj.Name()
	if pkg := obj.Pkg(); pkg != nil {
		x.Pkg = pkg.Name()
	}
	return &x
}
package main
// fileTmpl is the text/template for the generated client and server code.
// It is executed with an *Interface, and its output lives in the service's
// own package. Topic and ServiceGroup are spliced in verbatim, so they must
// already be quoted Go string literals.
var fileTmpl = `// Code generated by nats-rpc. DO NOT EDIT.

package {{ .Pkg }}

import (
	"context"
	"os"
	"os/signal"
	"syscall"

	"github.com/golang/protobuf/proto"
	"github.research.chop.edu/libi/transport"
)

// traceIdKeyType is the type of traceIdKey. A package-local key type keeps
// the context value from colliding with keys set by other packages, which an
// unnamed struct{} value would not.
type traceIdKeyType struct{}

// traceIdKey is the context key under which Serve stores the ID of the
// message being handled.
var traceIdKey = traceIdKeyType{}

// Client is the client side of {{ .Name }}.
type Client interface {
	{{ .Name }}
}

type client struct {
	tp transport.Transport
}
{{ range .Methods }}
func (c *client) {{ .Name }}(ctx context.Context, req *{{ .Request.Type }}) (*{{ .Response.Type }}, error) {
	var rep {{ .Response.Type }}
	_, err := c.tp.Request({{ .Topic }}, req, &rep)
	if err != nil {
		return nil, err
	}
	return &rep, nil
}
{{ end }}
// NewClient returns a Client that sends every call over tp.
func NewClient(tp transport.Transport) Client {
	return &client{tp}
}

// Serve subscribes one handler per method of {{ .Name }}. It then blocks
// until ctx is canceled or the process receives SIGINT or SIGTERM, and
// returns nil. If a subscription fails, Serve returns that error.
func Serve(ctx context.Context, tp transport.Transport, svc {{ .Name }}) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	var err error
	{{ range .Methods }}
	_, err = tp.Subscribe({{ .Topic }}, func(msg *transport.Message) (proto.Message, error) {
		ctx := context.WithValue(ctx, traceIdKey, msg.Id)
		var req {{ .Request.Type }}
		if err := msg.Decode(&req); err != nil {
			return nil, err
		}
		return svc.{{ .Name }}(ctx, &req)
	}, transport.SubscribeQueue({{ .ServiceGroup }}))
	if err != nil {
		return err
	}
	{{ end }}

	// signal.Notify never blocks when delivering, so the channel needs a
	// buffer or a signal arriving before the receive below is dropped.
	sigchan := make(chan os.Signal, 1)
	signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(sigchan)

	select {
	case <-sigchan:
	case <-ctx.Done():
	}
	return nil
}
`
// cliTmpl is the text/template for the generated command-line client. It is
// executed with an *Interface, and its output is a main package that
// imports the service package by name. The first positional argument
// selects the method; the optional second argument is the request as JSON
// (default "{}"). The JSON-encoded response is written to stdout.
var cliTmpl = `// Code generated by nats-rpc. DO NOT EDIT.

package main

import (
	"bytes"
	"context"
	"flag"
	"fmt"
	"os"

	"github.research.chop.edu/libi/log"
	"github.research.chop.edu/libi/{{ .Pkg }}"
	"github.research.chop.edu/libi/transport"
	"go.uber.org/zap"

	"github.com/golang/protobuf/proto"
	"github.com/golang/protobuf/jsonpb"
	"github.com/nats-io/go-nats"
)

const (
	clientType = "{{ .Pkg }}-cli"
)

var (
	buildVersion string

	jsonMarshaler = &jsonpb.Marshaler{
		EmitDefaults: true,
	}

	jsonUnmarshaler = &jsonpb.Unmarshaler{}
)

func main() {
	var (
		natsAddr     string
		printVersion bool
	)

	flag.StringVar(&natsAddr, "nats.addr", "nats://127.0.0.1:4222", "NATS address.")
	flag.BoolVar(&printVersion, "version", false, "Print version.")

	flag.Parse()

	if printVersion {
		fmt.Fprintln(os.Stdout, buildVersion)
		return
	}

	// Get method.
	args := flag.Args()
	if len(args) == 0 {
		log.Fatalf("method name required")
	}
	meth := args[0]

	// Initialize base logger.
	logger, err := log.New()
	if err != nil {
		log.Fatal(err)
	}

	logger = logger.With(
		zap.String("client.type", clientType),
		zap.String("client.version", buildVersion),
	)

	// Initialize the transport layer.
	tp, err := transport.Connect(&nats.Options{
		Url: natsAddr,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer tp.Close()

	tp.SetLogger(logger)

	inp := "{}"
	if len(args) > 1 {
		inp = args[1]
	}
	inpr := bytes.NewBufferString(inp)

	client := {{ .Pkg }}.NewClient(tp)

	var rep proto.Message
	ctx := context.Background()

	switch meth { {{ range .Methods }}
	case "{{ .Name }}":
		var req {{ .Pkg }}.{{ .Request.Type }}
		if err := jsonUnmarshaler.Unmarshal(inpr, &req); err != nil {
			log.Fatalf("json: %s", err)
		}
		rep, err = client.{{ .Name }}(ctx, &req)
	{{ end }}
	default:
		log.Fatalf("unknown method %s", meth)
	}

	if err != nil {
		log.Fatalf("rpc error: %s", err)
	}

	if err := jsonMarshaler.Marshal(os.Stdout, rep); err != nil {
		log.Fatalf("error encoding response: %s", err)
	}

	fmt.Fprint(os.Stdout, "\n")
}
`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment