Skip to content

Instantly share code, notes, and snippets.

@dan-compton dan-compton/bitflags.go
Last active Jun 27, 2016

Embed
What would you like to do?
gogo protobuf plugin that generates helper funcs such as UInt64() for bitflag message types (i.e. messages consisting only of bool fields)
/*
Given the protobuf below, generates the following funcs so that the message can be used as bitflags:
func (this *User) UInt64() uint64 {
b := uint64(0)
if this.ScopeA {
b |= uint64(1) << uint64(0)
}
if this.ScopeB {
b |= uint64(1) << uint64(1)
}
if this.ScopeC {
b |= uint64(1) << uint64(2)
}
return b
}
func (this *User) HighFlags() []string {
var b []string
if this.ScopeA {
b = append(b, "scope_a")
}
if this.ScopeB {
b = append(b, "scope_b")
}
if this.ScopeC {
b = append(b, "scope_c")
}
return b
}
func (this *User) LowFlags() []string {
var b []string
if !this.ScopeA {
b = append(b, "scope_a")
}
if !this.ScopeB {
b = append(b, "scope_b")
}
if !this.ScopeC {
b = append(b, "scope_c")
}
return b
}
func (this *User) SetFlag(flag string) error {
switch flag {
case "scope_a":
this.ScopeA = true
case "scope_b":
this.ScopeB = true
case "scope_c":
this.ScopeC = true
default:
return fmt.Errorf("invalid flag: %v", flag)
}
return nil
}
func (this *User) ClearFlag(flag string) error {
switch flag {
case "scope_a":
this.ScopeA = false
case "scope_b":
this.ScopeB = false
case "scope_c":
this.ScopeC = false
default:
return fmt.Errorf("invalid flag: %v", flag)
}
return nil
}
func (this *User) SetFlags(flags ...string) []error {
var errs []error
for _, f := range flags {
if err := this.SetFlag(f); err != nil {
errs = append(errs, err)
}
}
return errs
}
func (this *User) ClearFlags(flags ...string) []error {
var errs []error
for _, f := range flags {
if err := this.ClearFlag(f); err != nil {
errs = append(errs, err)
}
}
return errs
}
func (this *User) TestFlag(flag string) bool {
switch flag {
case "scope_a":
return this.ScopeA
case "scope_b":
return this.ScopeB
case "scope_c":
return this.ScopeC
}
return false
}
func (this *User) TestFlags(flags ...string) bool {
for _, f := range flags {
if !this.TestFlag(f) {
return false
}
}
return true
}
func (this *User) Scan(i interface{}) error {
switch v := i.(type) {
case int:
return this.FromUInt64(uint64(v))
case int32:
return this.FromUInt64(uint64(v))
case float32:
return this.FromUInt64(uint64(v))
case float64:
return this.FromUInt64(uint64(v))
}
return fmt.Errorf("invalid type: %T", i)
}
func (this *User) FromUInt64(b uint64) error {
bb := b
bb = b
if bb&(uint64(1)<<uint(0)) > 0 {
this.ScopeA = true
} else {
this.ScopeA = false
}
bb = b
if bb&(uint64(1)<<uint(1)) > 0 {
this.ScopeB = true
} else {
this.ScopeB = false
}
bb = b
if bb&(uint64(1)<<uint(2)) > 0 {
this.ScopeC = true
} else {
this.ScopeC = false
}
return nil
}
*/
// Example protobuf
/*
syntax = "proto3";
import "github.com/gogo/protobuf/gogoproto/gogo.proto";
import "github.com/bitflags/bitflag.proto";
package flavortown.flags;
message User {
option (bitflagproto.bitflags) = true;
bool scopeA = 1;
bool scopeB = 2;
bool scopeC = 3;
}
*/
package bitflags
import (
"bytes"
"unicode"
"github.com/gogo/protobuf/protoc-gen-gogo/generator"
"github.com/bitflags/protobuf/bitflagproto"
)
// plugin is the bitflags code-generator plugin. It embeds the generator, whose
// printing helpers (P, In, Out) are used to emit code, and the plugin import
// tracker that is created at the start of every Generate call.
type plugin struct {
*generator.Generator
generator.PluginImports
// messages holds the descriptors of the messages that the most recent
// Generate call found to be marked as bitflags.
messages []*generator.Descriptor
}
// NewBitflags returns a zero-valued bitflags plugin. Its generator is not set
// here; it is attached afterwards through Init.
func NewBitflags() *plugin {
	return new(plugin)
}
// Name returns the name of this plugin, "bitflags".
func (p *plugin) Name() string {
	const name = "bitflags"
	return name
}
// Init stores g as the plugin's embedded generator so that later calls can
// emit code through it.
func (p *plugin) Init(g *generator.Generator) {
p.Generator = g
}
// bitflagField pairs the Go name of a message field with the flag name that
// the generated string-based helpers accept for it.
type bitflagField struct {
	goName string // Go struct field name, e.g. "ScopeA"
	flag   string // snake_case flag name, e.g. "scope_a"
}

// Generate emits the bitflag helper methods for every message in file for
// which bitflagproto.IsBitflags reports true. Messages without the option are
// skipped.
//
// For each selected message the following methods are generated: UInt64,
// HighFlags, LowFlags, SetFlag, ClearFlag, SetFlags, ClearFlags, TestFlag,
// TestFlags, Scan and FromUInt64.
//
// Bit positions are the index of the field within message.Field, not the
// protobuf field number. Flag names are the snake_case form of the Go field
// name.
//
// NOTE: field types are not checked here; the emitted code treats every field
// as a bool. Likewise nothing limits the number of fields, although only bit
// positions 0-63 fit into a uint64.
func (p *plugin) Generate(file *generator.FileDescriptor) {
	p.PluginImports = generator.NewPluginImports(p.Generator)
	// Register "fmt" with the plugin import tracker so the emitted Errorf
	// calls do not depend on some other part of the generator importing it.
	fmtPkg := p.NewImport("fmt")
	p.messages = make([]*generator.Descriptor, 0)

	for _, message := range file.Messages() {
		if !bitflagproto.IsBitflags(file.FileDescriptorProto, message.DescriptorProto) {
			continue
		}
		p.messages = append(p.messages, message)

		typeName := generator.CamelCaseSlice(message.TypeName())
		fields := make([]bitflagField, 0, len(message.Field))
		for _, field := range message.Field {
			goName := p.GetFieldName(message, field)
			fields = append(fields, bitflagField{goName: goName, flag: snakeCase(goName)})
		}
		// Use() marks the import as used and yields the name to qualify with.
		errorf := fmtPkg.Use() + ".Errorf"

		p.genUInt64(typeName, fields)
		p.genFlagList(typeName, "HighFlags", "", fields)
		p.genFlagList(typeName, "LowFlags", "!", fields)
		p.genFlagAssign(typeName, "SetFlag", "true", errorf, fields)
		p.genFlagAssign(typeName, "ClearFlag", "false", errorf, fields)
		p.genFlagBulk(typeName, "SetFlags", "SetFlag")
		p.genFlagBulk(typeName, "ClearFlags", "ClearFlag")
		p.genTestFlag(typeName, fields)
		p.genTestFlags(typeName)
		p.genScan(typeName, errorf)
		p.genFromUInt64(typeName, fields)
	}
}

// genUInt64 emits UInt64(), which packs the bool fields into a uint64 with
// field i stored in bit i.
func (p *plugin) genUInt64(typeName string, fields []bitflagField) {
	p.P(`func (this *`, typeName, `) UInt64() uint64 {`)
	p.In()
	p.P(`b := uint64(0)`)
	for bit, f := range fields {
		p.P(`if this.`, f.goName, ` {`)
		p.In()
		p.P(`b |= uint64(1) << uint64(`, bit, `)`)
		p.Out()
		p.P(`}`)
	}
	p.P(`return b`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genFlagList emits a method returning the names of all flags whose field,
// prefixed with negate ("" or "!"), evaluates to true. It backs HighFlags
// (fields that are set) and LowFlags (fields that are clear).
func (p *plugin) genFlagList(typeName, method, negate string, fields []bitflagField) {
	p.P(`func (this *`, typeName, `) `, method, `() []string {`)
	p.In()
	p.P(`var b []string`)
	for _, f := range fields {
		p.P(`if `, negate, `this.`, f.goName, ` {`)
		p.In()
		p.P(`b = append(b, "`, f.flag, `")`)
		p.Out()
		p.P(`}`)
	}
	p.P(`return b`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genFlagAssign emits a method that assigns value ("true" or "false") to the
// field named by its flag argument and returns an error for an unknown flag.
// It backs SetFlag and ClearFlag.
func (p *plugin) genFlagAssign(typeName, method, value, errorf string, fields []bitflagField) {
	p.P(`func (this *`, typeName, `) `, method, `(flag string) error {`)
	p.In()
	p.P(`switch flag {`)
	for _, f := range fields {
		p.P(`case "`, f.flag, `":`)
		p.In()
		p.P(`this.`, f.goName, ` = `, value)
		p.Out()
	}
	p.P(`default:`)
	p.In()
	p.P(`return `, errorf, `("invalid flag: %v", flag)`)
	p.Out()
	p.P(`}`)
	p.P(`return nil`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genFlagBulk emits a variadic method that calls the single-flag method named
// single for every argument and collects the errors it returns. It backs
// SetFlags and ClearFlags.
func (p *plugin) genFlagBulk(typeName, method, single string) {
	p.P(`func (this *`, typeName, `) `, method, `(flags ...string) []error {`)
	p.In()
	p.P(`var errs []error`)
	p.P(`for _, f := range flags {`)
	p.In()
	p.P(`if err := this.`, single, `(f); err != nil {`)
	p.In()
	p.P(`errs = append(errs, err)`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)
	p.P(`return errs`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genTestFlag emits TestFlag(flag), which returns the value of the named
// field, or false when the flag name is unknown.
func (p *plugin) genTestFlag(typeName string, fields []bitflagField) {
	p.P(`func (this *`, typeName, `) TestFlag(flag string) bool {`)
	p.In()
	p.P(`switch flag {`)
	for _, f := range fields {
		p.P(`case "`, f.flag, `":`)
		p.In()
		p.P(`return this.`, f.goName)
		p.Out()
	}
	p.P(`}`)
	p.P(`return false`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genTestFlags emits TestFlags(flags...), which reports whether every named
// flag is set (logical AND over TestFlag).
func (p *plugin) genTestFlags(typeName string) {
	p.P(`func (this *`, typeName, `) TestFlags(flags ...string) bool {`)
	p.In()
	p.P(`for _, f := range flags {`)
	p.In()
	p.P(`if !this.TestFlag(f) {`)
	p.In()
	p.P(`return false`)
	p.Out()
	p.P(`}`)
	p.Out()
	p.P(`}`)
	p.P(`return true`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genScan emits Scan(i), which converts a numeric value to uint64 and loads
// it through FromUInt64, so the message can be read back from an integer
// column. int64 is included because that is the integer type database/sql
// passes to a Scanner.
func (p *plugin) genScan(typeName, errorf string) {
	p.P(`func (this *`, typeName, `) Scan(i interface{}) error {`)
	p.In()
	p.P(`switch v := i.(type) {`)
	for _, t := range []string{"int", "int32", "int64", "float32", "float64"} {
		p.P(`case `, t, `:`)
		p.In()
		p.P(`return this.FromUInt64(uint64(v))`)
		p.Out()
	}
	p.P(`}`)
	p.P(`return `, errorf, `("invalid type: %T", i)`)
	p.Out()
	p.P(`}`)
	p.P()
}

// genFromUInt64 emits FromUInt64(b), which sets every field from its bit in
// b: field i becomes true exactly when bit i is set.
// TODO(dan) should return error if overflow
func (p *plugin) genFromUInt64(typeName string, fields []bitflagField) {
	p.P(`func (this *`, typeName, `) FromUInt64(b uint64) error {`)
	p.In()
	for bit, f := range fields {
		// Test b directly: no scratch copy is emitted, so a message without
		// fields does not produce an unused local in the generated code.
		p.P(`this.`, f.goName, ` = b&(uint64(1)<<uint(`, bit, `)) != 0`)
	}
	p.P(`return nil`)
	p.Out()
	p.P(`}`)
	p.P()
}
// snakeCase converts a CamelCase identifier to lower-case snake_case.
//
// An underscore is inserted before an upper-case rune (other than the first
// rune) when that rune is followed by a lower-case rune or preceded by one.
// Runs of capitals therefore stay together: "HTTPServer" -> "http_server",
// "ScopeA" -> "scope_a".
func snakeCase(in string) string {
	rs := []rune(in)
	n := len(rs)
	buf := bytes.NewBuffer(make([]byte, 0, n))
	for idx, r := range rs {
		if idx > 0 && unicode.IsUpper(r) {
			nextIsLower := idx+1 < n && unicode.IsLower(rs[idx+1])
			prevIsLower := unicode.IsLower(rs[idx-1])
			if nextIsLower || prevIsLower {
				buf.WriteRune('_')
			}
		}
		buf.WriteRune(unicode.ToLower(r))
	}
	return buf.String()
}
// init registers a new bitflags plugin with the generator package when this
// package is initialised.
func init() {
generator.RegisterPlugin(NewBitflags())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.