Skip to content

Instantly share code, notes, and snippets.

@dancompton
Last active June 27, 2016 22:28
Show Gist options
  • Save dancompton/6224c03251686d5fe10a655416a1a7ea to your computer and use it in GitHub Desktop.
Save dancompton/6224c03251686d5fe10a655416a1a7ea to your computer and use it in GitHub Desktop.
gogo protobuf plugin that generates helper funcs such as UInt64() for bitflag message types (i.e. messages consisting only of bool fields)
/*
Given the protobuf below, generates the following funcs so that the message can be used as bitflags:
func (this *User) UInt64() uint64 {
b := uint64(0)
if this.ScopeA {
b |= uint64(1) << uint64(0)
}
if this.ScopeB {
b |= uint64(1) << uint64(1)
}
if this.ScopeC {
b |= uint64(1) << uint64(2)
}
return b
}
func (this *User) HighFlags() []string {
var b []string
if this.ScopeA {
b = append(b, "scope_a")
}
if this.ScopeB {
b = append(b, "scope_b")
}
if this.ScopeC {
b = append(b, "scope_c")
}
return b
}
func (this *User) LowFlags() []string {
var b []string
if !this.ScopeA {
b = append(b, "scope_a")
}
if !this.ScopeB {
b = append(b, "scope_b")
}
if !this.ScopeC {
b = append(b, "scope_c")
}
return b
}
func (this *User) SetFlag(flag string) error {
switch flag {
case "scope_a":
this.ScopeA = true
case "scope_b":
this.ScopeB = true
case "scope_c":
this.ScopeC = true
default:
return fmt.Errorf("invalid flag: %v", flag)
}
return nil
}
func (this *User) ClearFlag(flag string) error {
switch flag {
case "scope_a":
this.ScopeA = false
case "scope_b":
this.ScopeB = false
case "scope_c":
this.ScopeC = false
default:
return fmt.Errorf("invalid flag: %v", flag)
}
return nil
}
func (this *User) SetFlags(flags ...string) []error {
var errs []error
for _, f := range flags {
if err := this.SetFlag(f); err != nil {
errs = append(errs, err)
}
}
return errs
}
func (this *User) ClearFlags(flags ...string) []error {
var errs []error
for _, f := range flags {
if err := this.ClearFlag(f); err != nil {
errs = append(errs, err)
}
}
return errs
}
func (this *User) TestFlag(flag string) bool {
switch flag {
case "scope_a":
return this.ScopeA
case "scope_b":
return this.ScopeB
case "scope_c":
return this.ScopeC
}
return false
}
func (this *User) TestFlags(flags ...string) bool {
for _, f := range flags {
if !this.TestFlag(f) {
return false
}
}
return true
}
func (this *User) Scan(i interface{}) error {
switch v := i.(type) {
case int:
return this.FromUInt64(uint64(v))
case int32:
return this.FromUInt64(uint64(v))
case float32:
return this.FromUInt64(uint64(v))
case float64:
return this.FromUInt64(uint64(v))
}
return fmt.Errorf("invalid type: %T", i)
}
func (this *User) FromUInt64(b uint64) error {
bb := b
bb = b
if bb&(uint64(1)<<uint(0)) > 0 {
this.ScopeA = true
} else {
this.ScopeA = false
}
bb = b
if bb&(uint64(1)<<uint(1)) > 0 {
this.ScopeB = true
} else {
this.ScopeB = false
}
bb = b
if bb&(uint64(1)<<uint(2)) > 0 {
this.ScopeC = true
} else {
this.ScopeC = false
}
return nil
}
*/
// Example protobuf
/*
syntax = "proto3";
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
import "github.com/bitflags/bitflag.proto";
package flavortown.flags;
message User {
option (bitflagproto.bitflags) = true;
bool scopeA = 1;
bool scopeB = 2;
bool scopeC = 3;
}
*/
package bitflags
import (
"bytes"
"unicode"
"github.com/gogo/protobuf/protoc-gen-gogo/generator"
"github.com/bitflags/protobuf/bitflagproto"
)
// plugin is the gogo protobuf generator plugin that emits bitflag helper
// methods. It embeds the active *generator.Generator so output helpers such as
// P, In, Out and GetFieldName can be called directly on the plugin.
type plugin struct {
	// Assigned by Init.
	*generator.Generator
	// Re-created by Generate for every file.
	generator.PluginImports

	// messages holds the bitflag messages found in the file most recently
	// passed to Generate.
	messages []*generator.Descriptor
}
// NewBitflags returns a new, empty bitflags plugin. The generator it writes to
// is supplied later through Init.
func NewBitflags() *plugin {
	p := new(plugin)
	return p
}
// Name returns the plugin's name, "bitflags".
func (p *plugin) Name() string {
	const pluginName = "bitflags"
	return pluginName
}
// Init attaches the generator that later Generate calls print their output to.
func (p *plugin) Init(gen *generator.Generator) {
	p.Generator = gen
}
// Generate emits bitflag helper methods for every message in file that has the
// (bitflagproto.bitflags) option enabled. Each field of such a message is one
// bit. The bit position is the field's index in declaration order, not its
// proto field number, so reordering fields in the .proto changes the UInt64
// encoding.
//
// Generated methods, in order: UInt64, HighFlags, LowFlags, SetFlag,
// ClearFlag, SetFlags, ClearFlags, TestFlag, TestFlags, Scan, FromUInt64.
//
// The generated SetFlag, ClearFlag and Scan bodies call fmt.Errorf. This plugin
// does not register a fmt import itself.
// NOTE(review): relies on the generated .pb.go already importing fmt — confirm
// for the generator version in use.
//
// Generation is aborted (via Fail) if a bitflag message has a non-bool or
// repeated field, or more than 64 fields, because the emitted code would not
// compile in either case.
func (p *plugin) Generate(file *generator.FileDescriptor) {
	p.PluginImports = generator.NewPluginImports(p.Generator)
	p.messages = make([]*generator.Descriptor, 0)
	for _, message := range file.Messages() {
		if !bitflagproto.IsBitflags(file.FileDescriptorProto, message.DescriptorProto) {
			continue
		}
		p.messages = append(p.messages, message)
		typeName := generator.CamelCaseSlice(message.TypeName())
		p.checkBitflagFields(message, typeName)

		p.genUInt64(message, typeName)
		p.genFlagList(message, typeName, "HighFlags", "")
		p.genFlagList(message, typeName, "LowFlags", "!")
		p.genAssignFlag(message, typeName, "SetFlag", "true")
		p.genAssignFlag(message, typeName, "ClearFlag", "false")
		p.genApplyFlags(typeName, "SetFlags", "SetFlag")
		p.genApplyFlags(typeName, "ClearFlags", "ClearFlag")
		p.genTestFlag(message, typeName)
		p.genTestFlags(typeName)
		p.genScan(typeName)
		p.genFromUInt64(message, typeName)
	}
}

// checkBitflagFields aborts generation when message cannot be used as bitflags.
// Every field must be a singular bool, and there can be at most 64 fields
// because each one occupies a bit of a uint64. Without this check, the
// generated shifts (e.g. uint64(1) << uint64(64)) or the `if this.Field`
// tests would fail to compile.
// NOTE(review): proto2 optional bools are generated as *bool and would still
// not compile; the example .proto uses proto3.
func (p *plugin) checkBitflagFields(message *generator.Descriptor, typeName string) {
	if len(message.Field) > 64 {
		p.Fail("bitflags:", typeName, "has more than 64 fields; at most 64 fit in a uint64")
	}
	for _, field := range message.Field {
		if !field.IsBool() || field.IsRepeated() {
			p.Fail("bitflags:", typeName+"."+field.GetName(), "must be a singular bool field")
		}
	}
}

// genUInt64 emits UInt64() uint64. It packs the bool fields into a uint64 with
// field i (declaration order) at bit i.
func (p *plugin) genUInt64(message *generator.Descriptor, typeName string) {
	p.P(`func (this *`, typeName, `) UInt64() uint64 {`)
	p.In()
	p.P(`b := uint64(0)`)
	for bit, field := range message.Field {
		p.P(`if this.`, p.GetFieldName(message, field), ` {`)
		p.In()
		p.P(`b |= uint64(1) << uint64(`, bit, `)`)
		p.Out()
		p.P(`}`)
	}
	p.P(`return b`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genFlagList emits funcName() []string, which returns the snake_case names of
// the fields satisfying `<negate>this.Field`. negate is "" to list set fields
// (HighFlags) or "!" to list unset fields (LowFlags).
func (p *plugin) genFlagList(message *generator.Descriptor, typeName, funcName, negate string) {
	p.P(`func (this *`, typeName, `) `, funcName, `() []string {`)
	p.In()
	p.P(`var b []string`)
	for _, field := range message.Field {
		fieldName := p.GetFieldName(message, field)
		p.P(`if `, negate, `this.`, fieldName, ` {`)
		p.In()
		p.P(`b = append(b, "`, snakeCase(fieldName), `")`)
		p.Out()
		p.P(`}`)
	}
	p.P(`return b`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genAssignFlag emits funcName(flag string) error. The generated method assigns
// value ("true" or "false") to the field whose snake_case name equals flag, or
// returns an "invalid flag" error when no field matches.
func (p *plugin) genAssignFlag(message *generator.Descriptor, typeName, funcName, value string) {
	p.P(`func (this *`, typeName, `) `, funcName, `(flag string) error {`)
	p.In()
	p.P(`switch flag {`)
	for _, field := range message.Field {
		fieldName := p.GetFieldName(message, field)
		p.P(`case "`, snakeCase(fieldName), `":`)
		p.In()
		p.P(`this.`, fieldName, ` = `, value)
		p.Out()
	}
	p.P(`default:`)
	p.In()
	p.P(`return fmt.Errorf("invalid flag: %v", flag)`)
	p.Out()
	p.P(`}`)
	p.P(`return nil`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genApplyFlags emits funcName(flags ...string) []error. The generated method
// calls the single-flag method named single for every argument and collects
// each error returned; the result is nil when every flag was valid.
func (p *plugin) genApplyFlags(typeName, funcName, single string) {
	p.P(`func (this *`, typeName, `) `, funcName, `(flags ...string) []error {`)
	p.In()
	p.P(`var errs []error`)
	p.P(`for _, f := range flags {`)
	p.In()
	p.P(`if err := this.`, single, `(f); err != nil {`)
	p.In()
	p.P(`errs = append(errs, err)`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)
	p.P(`return errs`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genTestFlag emits TestFlag(flag string) bool. The generated method returns the
// value of the field whose snake_case name equals flag, or false when no field
// matches.
func (p *plugin) genTestFlag(message *generator.Descriptor, typeName string) {
	p.P(`func (this *`, typeName, `) TestFlag(flag string) bool {`)
	p.In()
	p.P(`switch flag {`)
	for _, field := range message.Field {
		fieldName := p.GetFieldName(message, field)
		p.P(`case "`, snakeCase(fieldName), `":`)
		p.In()
		p.P(`return this.`, fieldName)
		p.Out()
	}
	p.P(`}`)
	p.P(`return false`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genTestFlags emits TestFlags(flags ...string) bool, which reports whether
// every named flag is set (logical AND). It is true for an empty argument list.
func (p *plugin) genTestFlags(typeName string) {
	p.P(`func (this *`, typeName, `) TestFlags(flags ...string) bool {`)
	p.In()
	p.P(`for _, f := range flags {`)
	p.In()
	p.P(`if !this.TestFlag(f) {`)
	p.In()
	p.P(`return false`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)
	p.P(`return true`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genScan emits Scan(i interface{}) error so the message can be read back from
// an integer database column. Numeric inputs are converted to uint64 and
// decoded with FromUInt64; any other type yields an "invalid type" error.
// int64 must be accepted because database/sql delivers integer column values
// to a Scanner as int64.
func (p *plugin) genScan(typeName string) {
	p.P(`func (this *`, typeName, `) Scan(i interface{}) error {`)
	p.In()
	p.P(`switch v := i.(type) {`)
	numericTypes := []string{"int", "int32", "int64", "uint", "uint32", "uint64", "float32", "float64"}
	for _, t := range numericTypes {
		p.P(`case `, t, `:`)
		p.In()
		p.P(`return this.FromUInt64(uint64(v))`)
		p.Out()
	}
	p.P(`}`)
	p.P(`return fmt.Errorf("invalid type: %T", i)`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genFromUInt64 emits FromUInt64(b uint64) error, the inverse of UInt64. Field i
// (declaration order) is set to whether bit i of b is 1. Bits beyond the
// field count are ignored, and the method always returns nil.
// TODO(dan) should return error if overflow
func (p *plugin) genFromUInt64(message *generator.Descriptor, typeName string) {
	p.P(`func (this *`, typeName, `) FromUInt64(b uint64) error {`)
	p.In()
	for bit, field := range message.Field {
		// Assign directly (no temporary copy of b) so a message with no
		// fields still yields compilable code.
		p.P(`this.`, p.GetFieldName(message, field), ` = b&(uint64(1)<<uint(`, bit, `)) != 0`)
	}
	p.P(`return nil`)
	p.Out()
	p.P(`}`)
	p.P()
}
// snakeCase converts a Go CamelCase identifier to lower snake_case. An
// underscore is inserted before every upper-case rune except the first, when
// that rune starts a new word. That means either the previous rune is
// lower-case ("fooBar" -> "foo_bar") or the next rune is lower-case, which
// keeps acronyms together ("HTTPServer" -> "http_server").
func snakeCase(in string) string {
	rs := []rune(in)
	n := len(rs)
	var buf bytes.Buffer
	buf.Grow(n)
	for i, r := range rs {
		if i > 0 && unicode.IsUpper(r) {
			prevLower := unicode.IsLower(rs[i-1])
			nextLower := i+1 < n && unicode.IsLower(rs[i+1])
			if prevLower || nextLower {
				buf.WriteByte('_')
			}
		}
		buf.WriteRune(unicode.ToLower(r))
	}
	return buf.String()
}
func init() {
generator.RegisterPlugin(NewBitflags())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment