Skip to content

Instantly share code, notes, and snippets.

@NickSablukov
Created November 17, 2023 19:54
Show Gist options
  • Save NickSablukov/fb4725f513e9f10cf8e7928f5ffae015 to your computer and use it in GitHub Desktop.
Save NickSablukov/fb4725f513e9f10cf8e7928f5ffae015 to your computer and use it in GitHub Desktop.
compare_thrift.go
package main
import (
	"errors"
	"fmt"
	"os"
	"reflect"

	"github.com/samuel/go-thrift/parser"
)
// Labels used when rendering comparison errors and canonical type names.
const (
	// Field requiredness labels: a field is either optional or required.
	required = "required"
	optional = "optional"

	// Method call-style labels: "oneway" for methods declared oneway,
	// "dually" for every method that is not.
	oneway = "oneway"
	dually = "dually"

	// Container type names as they appear in parser.Type.Name.
	mapType  = "map"
	listType = "list"
	setType  = "set"
)
// main compares the client contract (first argument) with the server contract
// (second argument). When the arguments are wrong, a file cannot be loaded, or
// a comparator reports an incompatibility, it prints the reason to stderr and
// exits with status 1; otherwise it returns normally (status 0).
func main() {
	// fail prints msg to stderr and terminates the process. os.Exit never
	// returns, so callers need no trailing return statement.
	fail := func(msg string) {
		fmt.Fprintln(os.Stderr, msg)
		os.Exit(1)
	}

	if len(os.Args) != 3 {
		fail("expected exactly 2 args. For example: `main.go client.thrift server.thrift`")
	}
	client, err := getThrift(os.Args[1])
	if err != nil {
		fail(err.Error())
	}
	server, err := getThrift(os.Args[2])
	if err != nil {
		fail(err.Error())
	}

	// Compare client with server; each comparator returns the first
	// incompatibility it finds, and the first failing comparator stops the run.
	// TODO: decide whether annotations need to be compared as well.
	comparators := []func(client, server *parser.Thrift) error{
		compareConstants,
		compareStructs,
		compareServices,
		compareExceptions,
		compareUnions,
		compareEnums,
	}
	for _, compare := range comparators {
		if err := compare(client, server); err != nil {
			fail(err.Error())
		}
	}
}
// compareUnions checks every union declared by the client against the
// server's union of the same name, field by field (see compareField).
func compareUnions(client, server *parser.Thrift) error {
	return compareStructMaps("union", client.Unions, server.Unions)
}

// compareExceptions checks every exception declared by the client against the
// server's exception of the same name, field by field (see compareField).
func compareExceptions(client, server *parser.Thrift) error {
	return compareStructMaps("exception", client.Exceptions, server.Exceptions)
}

// compareEnums checks that every enum declared by the client exists on the
// server, and that every client enum value exists there with the same number.
func compareEnums(client, server *parser.Thrift) error {
	serverEnums := elemsMap(server.Enums, func(o *parser.Enum) string {
		return o.Name
	})
	for _, c := range client.Enums {
		s, ok := serverEnums[c.Name]
		if !ok {
			return fmt.Errorf(`server doesn't have enum "%s", but client declare it`, c.Name)
		}
		serverValues := elemsMap(s.Values, func(o *parser.EnumValue) string {
			return o.Name
		})
		for _, cv := range c.Values {
			sv, ok := serverValues[cv.Name]
			if !ok {
				return fmt.Errorf(
					`server's enum "%s" doesn't have value "%s", but client declare it`,
					s.Name, cv.Name,
				)
			}
			// Thrift serializes enum values as integers, so a renumbered value
			// is an incompatibility even when the names match.
			if sv.Value != cv.Value {
				return fmt.Errorf(
					`server's enum "%s" has value "%s" = %d, but client declare %d`,
					s.Name, sv.Name, sv.Value, cv.Value,
				)
			}
		}
	}
	return nil
}

// compareStructMaps checks every struct-like definition in clientDefs against
// the definition of the same name in serverDefs. Fields are matched by name
// and compared with compareField. kind (e.g. "union") is used only in messages.
func compareStructMaps(kind string, clientDefs, serverDefs map[string]*parser.Struct) error {
	byName := elemsMap(serverDefs, func(o *parser.Struct) string {
		return o.Name
	})
	for _, c := range clientDefs {
		s, ok := byName[c.Name]
		if !ok {
			return fmt.Errorf(`server doesn't have %s "%s"`, kind, c.Name)
		}
		serverFields := elems(s.Fields, func(o *parser.Field) string {
			return o.Name
		})
		for _, cf := range c.Fields {
			sf, ok := serverFields[cf.Name]
			if !ok {
				return fmt.Errorf(
					`server's %s "%s" doesn't have field "%s", but client's "%s" declare it`,
					kind, s.Name, cf.Name, c.Name,
				)
			}
			if err := compareField(fmt.Sprintf(`%s "%s"`, kind, s.Name), cf, sf); err != nil {
				return err
			}
		}
	}
	return nil
}
// compareConstants checks that every constant declared by the client exists
// on the server with an equal value and the same type. It returns the first
// mismatch found, or nil.
func compareConstants(client, server *parser.Thrift) error {
	serverConstants := elemsMap(server.Constants, func(o *parser.Constant) string {
		return o.Name
	})
	for _, c := range client.Constants {
		s, ok := serverConstants[c.Name]
		if !ok {
			return fmt.Errorf(`server doesn't have constant "%s"`, c.Name)
		}
		// Value is an interface{}. It may hold an uncomparable dynamic type
		// (e.g. a slice for list/map constants), which makes == panic, so
		// compare structurally. Use %v because the values are not always strings;
		// a plain %s would print e.g. "%!s(int64=5)".
		if !reflect.DeepEqual(s.Value, c.Value) {
			return fmt.Errorf(
				`constant "%s" has value "%v" in server contract, but in client it's "%v"`,
				s.Name, s.Value, c.Value,
			)
		}
		if err := compareTypes(fmt.Sprintf(`constant "%s"`, s.Name), c.Type, s.Type); err != nil {
			return err
		}
	}
	return nil
}
// compareField checks that clientField matches serverField.
// It compares, in this order, ID, default value, requiredness and finally type.
// objectName describes the enclosing definition (e.g. `struct "User"`) and is
// used only in error messages. It returns the first mismatch found, or nil.
func compareField(objectName string, clientField, serverField *parser.Field) error {
	requiredness := func(isOptional bool) string {
		if isOptional {
			return optional
		}
		return required
	}
	switch {
	case serverField.ID != clientField.ID:
		return fmt.Errorf(
			`server's %s has field "%s" with ID "%d", but client declare "%d"`,
			objectName, serverField.Name, serverField.ID, clientField.ID,
		)
	// Default is an interface{} that may hold an uncomparable value (e.g. a
	// slice for list/map defaults); == would panic there, DeepEqual does not.
	case !reflect.DeepEqual(serverField.Default, clientField.Default):
		return fmt.Errorf(
			`server's %s has field "%s" with default value "%v", but client declare "%v"`,
			objectName, serverField.Name, serverField.Default, clientField.Default,
		)
	case serverField.Optional != clientField.Optional:
		return fmt.Errorf(
			`server's %s has %s field "%s", but client declare it as %s`,
			objectName,
			requiredness(serverField.Optional),
			serverField.Name,
			requiredness(clientField.Optional),
		)
	}
	return compareTypes(
		fmt.Sprintf(`%s's field "%s"`, objectName, serverField.Name),
		clientField.Type,
		serverField.Type,
	)
}
// compareStructs walks every struct declared in the client contract. For each
// one, the server contract must declare a struct with the same name, and each
// client field must match the server field of the same name. The first mismatch
// is returned; nil means every client struct is covered by the server.
func compareStructs(client, server *parser.Thrift) error {
	structsByName := elemsMap(server.Structs, func(st *parser.Struct) string { return st.Name })
	for _, clientStruct := range client.Structs {
		serverStruct, found := structsByName[clientStruct.Name]
		if !found {
			return errors.New(fmt.Sprintf(`server doesn't have struct "%s"`, clientStruct.Name))
		}
		if err := matchStructFields(clientStruct, serverStruct); err != nil {
			return err
		}
	}
	return nil
}

// matchStructFields checks each client field against the server field that
// carries the same name, using compareField for the detailed comparison.
func matchStructFields(clientStruct, serverStruct *parser.Struct) error {
	fieldsByName := elems(serverStruct.Fields, func(fld *parser.Field) string { return fld.Name })
	label := fmt.Sprintf(`struct "%s"`, serverStruct.Name)
	for _, clientField := range clientStruct.Fields {
		serverField, found := fieldsByName[clientField.Name]
		if !found {
			return errors.New(fmt.Sprintf(
				`server's struct "%s" doesn't have field "%s", but client's "%s" declare it'`,
				serverStruct.Name, clientField.Name, clientStruct.Name,
			))
		}
		if err := compareField(label, clientField, serverField); err != nil {
			return err
		}
	}
	return nil
}
// compareTypes reports an error when the client and server types render to
// different canonical strings (see representType). objectName says whose type
// is being compared and only appears in the error message.
func compareTypes(objectName string, clientType, serverType *parser.Type) error {
	clientRepr, serverRepr := representType(clientType), representType(serverType)
	if clientRepr == serverRepr {
		return nil
	}
	msg := fmt.Sprintf(`server's %s has type %s, but client declare %s'`, objectName, serverRepr, clientRepr)
	return errors.New(msg)
}
// compareServices checks that every service the client declares exists on the
// server with the same Extends value. It also checks that every client method
// exists on the server and matches it (see compareMethod). It returns the first
// mismatch found, or nil.
func compareServices(client, server *parser.Thrift) error {
	serverServices := elemsMap(server.Services, func(o *parser.Service) string {
		return o.Name
	})
	for _, c := range client.Services {
		s, ok := serverServices[c.Name]
		if !ok {
			return fmt.Errorf(`server doesn't have service "%s", but client declare it`, c.Name)
		}
		if s.Extends != c.Extends {
			return fmt.Errorf(
				`server's service "%s" extends "%s", but client's "%s" declare "%s"`,
				s.Name, s.Extends, c.Name, c.Extends,
			)
		}
		sMethods := elemsMap(s.Methods, func(o *parser.Method) string {
			return o.Name
		})
		for _, cM := range c.Methods {
			sM, ok := sMethods[cM.Name]
			if !ok {
				return fmt.Errorf(
					`server doesn't have method "%s.%s", but client declare it`,
					c.Name, cM.Name,
				)
			}
			if err := compareMethod(c.Name, cM, sM); err != nil {
				return err
			}
		}
	}
	return nil
}
// compareMethod checks that clientMethod is call-compatible with serverMethod.
// Both must agree on whether the method is oneway and on the return type. They
// must also declare the same arguments and the same exceptions: equal counts,
// and pairwise the same name, ID, requiredness and type, in declaration order.
// structName is the enclosing service's name and is used only in messages.
func compareMethod(structName string, clientMethod, serverMethod *parser.Method) error {
	callKind := func(isOneway bool) string {
		if isOneway {
			return oneway
		}
		return dually
	}
	requiredness := func(isOptional bool) string {
		if isOptional {
			return optional
		}
		return required
	}

	// compareFieldLists compares two positional field lists. plural names the
	// list in the count message, item labels an entry in per-field messages,
	// and typeLabel labels it in type-mismatch messages.
	compareFieldLists := func(plural, item, typeLabel string, clientFields, serverFields []*parser.Field) error {
		if len(clientFields) != len(serverFields) {
			return fmt.Errorf(
				`server's method "%s.%s" has %d %s, but client's "%s.%s" has %d`,
				structName, serverMethod.Name, len(serverFields), plural,
				structName, clientMethod.Name, len(clientFields),
			)
		}
		for i, sArg := range serverFields {
			cArg := clientFields[i]
			switch {
			case cArg.Name != sArg.Name:
				return fmt.Errorf(
					`server method's "%s.%s" %s №%d has name "%s", but client's declare "%s"`,
					structName, clientMethod.Name, item, i+1, sArg.Name, cArg.Name,
				)
			case cArg.ID != sArg.ID:
				return fmt.Errorf(
					`server method's "%s.%s" %s "%s" has ID "%d", but client's declare "%d"`,
					structName, clientMethod.Name, item, sArg.Name, sArg.ID, cArg.ID,
				)
			case cArg.Optional != sArg.Optional:
				return fmt.Errorf(
					`server method's "%s.%s" %s "%s" is %s, but client's is %s`,
					structName, clientMethod.Name, item, sArg.Name,
					requiredness(sArg.Optional), requiredness(cArg.Optional),
				)
			}
			if err := compareTypes(
				fmt.Sprintf(`method's "%s.%s" %s "%s"`, structName, serverMethod.Name, typeLabel, cArg.Name),
				cArg.Type,
				sArg.Type,
			); err != nil {
				return err
			}
		}
		return nil
	}

	if serverMethod.Oneway != clientMethod.Oneway {
		return fmt.Errorf(
			`server's method "%s.%s" is %s, but client's "%s.%s" is %s`,
			structName, serverMethod.Name, callKind(serverMethod.Oneway),
			structName, clientMethod.Name, callKind(clientMethod.Oneway),
		)
	}
	if err := compareTypes(
		fmt.Sprintf(`method's "%s.%s" returning`, structName, serverMethod.Name),
		clientMethod.ReturnType,
		serverMethod.ReturnType,
	); err != nil {
		return err
	}
	if err := compareFieldLists("arguments", "field", "argument",
		clientMethod.Arguments, serverMethod.Arguments); err != nil {
		return err
	}
	// A server throwing an exception the client does not declare (or declares
	// with another ID/type) cannot be decoded by the client, so exceptions are
	// checked with the same rules as arguments.
	return compareFieldLists("exceptions", "exception", "exception",
		clientMethod.Exceptions, serverMethod.Exceptions)
}
// getThrift reads the Thrift IDL file at path file and parses it.
// It returns an error if the file cannot be read or fails to parse. It also
// returns an error if the parser yields something other than a non-nil
// *parser.Thrift.
func getThrift(file string) (*parser.Thrift, error) {
	fileContent, err := os.ReadFile(file)
	if err != nil {
		return nil, err
	}
	r, err := parser.Parse(file, fileContent)
	if err != nil {
		return nil, err
	}
	// Parse returns an untyped value; check the assertion instead of panicking
	// on an unexpected result.
	thrift, ok := r.(*parser.Thrift)
	if !ok || thrift == nil {
		return nil, fmt.Errorf("%s: unexpected parse result of type %T", file, r)
	}
	return thrift, nil
}
// elemsMap re-keys the values of a string-keyed map by the key f derives from
// each value; the original map keys are ignored. The result is never nil. If
// f maps two values to the same key, which one survives depends on Go's map
// iteration order.
func elemsMap[Obj any, Key comparable](arr map[string]Obj, f func(o Obj) Key) map[Key]Obj {
	rekeyed := make(map[Key]Obj, len(arr))
	for name := range arr {
		value := arr[name]
		rekeyed[f(value)] = value
	}
	return rekeyed
}
// elems indexes a slice by the key f derives from each element. When several
// elements share a key, the one appearing last in arr wins. The result is never
// nil.
func elems[Obj any, Key comparable](arr []Obj, f func(o Obj) Key) map[Key]Obj {
	indexed := make(map[Key]Obj, len(arr))
	for i := range arr {
		indexed[f(arr[i])] = arr[i]
	}
	return indexed
}
// representType renders a parser type as a canonical string so that two
// types can be compared textually. Named and base types render as their name.
// Containers render recursively as map<K,V>, list<V> or set<V>.
// A nil type renders as "void" instead of dereferencing nil. This presumably
// covers method return types declared void; verify against the parser.
func representType(f *parser.Type) string {
	if f == nil {
		return "void"
	}
	switch f.Name {
	case mapType:
		return fmt.Sprintf("%s<%s,%s>", f.Name, representType(f.KeyType), representType(f.ValueType))
	case listType, setType:
		return fmt.Sprintf("%s<%s>", f.Name, representType(f.ValueType))
	default:
		return f.Name
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment