Created
November 17, 2023 19:54
-
-
Save NickSablukov/fb4725f513e9f10cf8e7928f5ffae015 to your computer and use it in GitHub Desktop.
compare_thrift.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import (
	"errors"
	"fmt"
	"os"
	"reflect"

	"github.com/samuel/go-thrift/parser"
)
// Labels used in diagnostics to describe whether a field is required.
const (
	required = "required"
	optional = "optional"
)

// Labels used in diagnostics to describe a method's call mode.
const (
	oneway = "oneway"
	dually = "dually"
)

// Names of the Thrift container types; representType matches them against
// a type's Name to decide which element types to render.
const (
	mapType  = "map"
	listType = "list"
	setType  = "set"
)
func main() { | |
if len(os.Args) != 3 { | |
fmt.Println("missing args. For example: `main.go client.thrift server.thrift`") | |
os.Exit(1) | |
return | |
} | |
client, err := getThrift(os.Args[1]) | |
if err != nil { | |
fmt.Println(err.Error()) | |
os.Exit(1) | |
return | |
} | |
server, err := getThrift(os.Args[2]) | |
if err != nil { | |
fmt.Println(err.Error()) | |
os.Exit(1) | |
return | |
} | |
// compare client with server | |
// TODO: do we need to check annotations? оО | |
for _, f := range []func(client, server *parser.Thrift) error{ | |
compareConstants, | |
compareStructs, | |
compareServices, | |
compareExceptions, // TODO - implement me! | |
compareUnions, // TODO - implement me! | |
compareEnums, // TODO - implement me! | |
} { | |
if err := f(client, server); err != nil { | |
println(err.Error()) | |
os.Exit(1) | |
return | |
} | |
} | |
} | |
func compareUnions(client, server *parser.Thrift) error { return nil } | |
func compareExceptions(client, server *parser.Thrift) error { return nil } | |
func compareEnums(client, server *parser.Thrift) error { return nil } | |
func compareConstants(client, server *parser.Thrift) error { | |
serverConstants := elemsMap(server.Constants, func(o *parser.Constant) string { | |
return o.Name | |
}) | |
for _, c := range client.Constants { | |
s, ok := serverConstants[c.Name] | |
switch false { | |
case ok: | |
return errors.New(fmt.Sprintf(`server doesn't have constant "%s"`, c.Name)) | |
// TODO: we'll see "%!s(int64=5)", and we need to humanize it | |
case s.Value == c.Value: | |
return errors.New(fmt.Sprintf( | |
`constant "%s" has value "%s" in server contract, but in client it's "%s"'`, | |
s.Name, | |
s.Value, | |
c.Value, | |
)) | |
} | |
if err := compareTypes(fmt.Sprintf(`constant "%s"`, s.Name), c.Type, s.Type); err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
func compareField(objectName string, clientField, serverField *parser.Field) error { | |
switch false { | |
case serverField.ID == clientField.ID: | |
return errors.New(fmt.Sprintf( | |
`server's %s has field "%s" with ID "%d", but client declare "%d"'`, | |
objectName, | |
serverField.Name, | |
serverField.ID, | |
clientField.ID, | |
)) | |
case serverField.Default == clientField.Default: | |
return errors.New(fmt.Sprintf( | |
`server's %s has field "%s" with default value "%v", but client's "%s.%s" declare "%v"'`, | |
objectName, | |
serverField.Name, | |
serverField.Default, | |
clientField.Default, | |
)) | |
case serverField.Optional == clientField.Optional: | |
var ( | |
_s = required | |
_c = required | |
) | |
if serverField.Optional { | |
_s = optional | |
} | |
if clientField.Optional { | |
_c = optional | |
} | |
return errors.New(fmt.Sprintf( | |
`server's %s has %s field "%s", but client declare it as %s'`, | |
objectName, | |
_s, | |
serverField.Name, | |
_c, | |
)) | |
} | |
if err := compareTypes( | |
fmt.Sprintf(`%s's field "%s"`, objectName, serverField.Name), | |
clientField.Type, | |
serverField.Type, | |
); err != nil { | |
return err | |
} | |
return nil | |
} | |
func compareStructs(client, server *parser.Thrift) error { | |
serverStructs := elemsMap(server.Structs, func(o *parser.Struct) string { | |
return o.Name | |
}) | |
for _, c := range client.Structs { | |
s, ok := serverStructs[c.Name] | |
if !ok { | |
return errors.New(fmt.Sprintf(`server doesn't have struct "%s"`, c.Name)) | |
} | |
sFields := elems(s.Fields, func(o *parser.Field) string { | |
return o.Name | |
}) | |
for _, field := range c.Fields { | |
sField, ok := sFields[field.Name] | |
if !ok { | |
return errors.New(fmt.Sprintf( | |
`server's struct "%s" doesn't have field "%s", but client's "%s" declare it'`, | |
s.Name, | |
field.Name, | |
c.Name, | |
)) | |
} | |
if err := compareField(fmt.Sprintf(`struct "%s"`, s.Name), field, sField); err != nil { | |
return err | |
} | |
} | |
} | |
return nil | |
} | |
func compareTypes(objectName string, clientType, serverType *parser.Type) error { | |
fullClientFieldType := representType(clientType) | |
fullServerFieldType := representType(serverType) | |
if fullClientFieldType != fullServerFieldType { | |
return errors.New(fmt.Sprintf( | |
`server's %s has type %s, but client declare %s'`, | |
objectName, | |
fullServerFieldType, | |
fullClientFieldType, | |
)) | |
} | |
return nil | |
} | |
func compareServices(client, server *parser.Thrift) error { | |
serverServices := elemsMap(server.Services, func(o *parser.Service) string { | |
return o.Name | |
}) | |
for _, c := range client.Services { | |
s, ok := serverServices[c.Name] | |
switch false { | |
case ok: | |
return errors.New(fmt.Sprintf( | |
`server doesn't have service "%s", but client declare it`, | |
c.Name, | |
)) | |
case s.Extends == c.Extends: | |
return errors.New(fmt.Sprintf( | |
`server's service "%s" has extensds "%s", but client's "%s" declare "%s"`, | |
s.Name, | |
s.Extends, | |
c.Name, | |
c.Extends, | |
)) | |
} | |
sMethods := elemsMap(s.Methods, func(o *parser.Method) string { | |
return o.Name | |
}) | |
for _, cM := range c.Methods { | |
sM, ok := sMethods[cM.Name] | |
if !ok { | |
return errors.New(fmt.Sprintf( | |
`server doesn't have method "%s.%s", but client declare it`, | |
c.Name, | |
cM.Name, | |
)) | |
} | |
if err := compareMethod(c.Name, cM, sM); err != nil { | |
return err | |
} | |
} | |
} | |
return nil | |
} | |
func compareMethod(structName string, clientMethod, serverMethod *parser.Method) error { | |
if serverMethod.Oneway != clientMethod.Oneway { | |
_s := oneway | |
_c := oneway | |
if !serverMethod.Oneway { | |
_s = dually | |
} | |
if !clientMethod.Oneway { | |
_c = dually | |
} | |
return errors.New(fmt.Sprintf( | |
`server's method "%s.%s" is %s, but client's "%s.%s" is %s`, | |
structName, | |
serverMethod.Name, | |
_s, | |
structName, | |
clientMethod.Name, | |
_c, | |
)) | |
} | |
if err := compareTypes( | |
fmt.Sprintf(`method's "%s.%s" returning`, structName, serverMethod.Name), | |
clientMethod.ReturnType, | |
serverMethod.ReturnType, | |
); err != nil { | |
return err | |
} | |
if len(clientMethod.Arguments) != len(serverMethod.Arguments) { | |
return errors.New(fmt.Sprintf( | |
`server's method "%s.%s" has %d arguments, but client's "%s.%s" has %d`, | |
structName, | |
serverMethod.Name, | |
len(serverMethod.Arguments), | |
structName, | |
clientMethod.Name, | |
len(clientMethod.Arguments), | |
)) | |
} | |
for i := 0; i < len(serverMethod.Arguments); i++ { | |
cArg := clientMethod.Arguments[i] | |
sArg := serverMethod.Arguments[i] | |
switch false { | |
case cArg.Name == sArg.Name: | |
return errors.New(fmt.Sprintf( | |
`server method's "%s.%s" field №%d has name "%s", but client's declare "%s"'`, | |
structName, | |
clientMethod.Name, | |
i+1, | |
sArg.Name, | |
cArg.Name, | |
)) | |
case cArg.ID == sArg.ID: | |
return errors.New(fmt.Sprintf( | |
`server method's "%s.%s" field "%s" has ID "%d", but client's declare "%d"'`, | |
structName, | |
clientMethod.Name, | |
sArg.Name, | |
sArg.ID, | |
cArg.ID, | |
)) | |
case cArg.Optional == sArg.Optional: | |
var ( | |
_s = required | |
_c = required | |
) | |
if sArg.Optional { | |
_s = optional | |
} | |
if cArg.Optional { | |
_c = optional | |
} | |
return errors.New(fmt.Sprintf( | |
`server method's "%s.%s" field "%s" is %s, but client's is %s`, | |
structName, | |
clientMethod.Name, | |
sArg.Name, | |
_s, | |
_c, | |
)) | |
} | |
if err := compareTypes( | |
fmt.Sprintf(`method's "%s.%s" argument "%s"`, structName, serverMethod.Name, cArg.Name), | |
cArg.Type, | |
sArg.Type, | |
); err != nil { | |
return err | |
} | |
} | |
return nil | |
} | |
func getThrift(file string) (*parser.Thrift, error) { | |
fileContent, err := os.ReadFile(file) | |
if err != nil { | |
return nil, err | |
} | |
r, err := parser.Parse(file, fileContent) | |
if err != nil { | |
return nil, err | |
} | |
return r.(*parser.Thrift), nil | |
} | |
// elemsMap re-indexes the values of arr under the key that f derives from
// each value. The original string keys of arr are ignored. When f yields the
// same key for several values, only one of them is kept; which one depends
// on map iteration order. The result is always a non-nil map.
func elemsMap[Obj any, Key comparable](arr map[string]Obj, f func(o Obj) Key) map[Key]Obj {
	byKey := make(map[Key]Obj, len(arr))
	for _, value := range arr {
		key := f(value)
		byKey[key] = value
	}
	return byKey
}
// elems indexes the elements of arr under the key that f derives from each
// element. When several elements share a key, the one that appears last in
// arr wins. The result is always a non-nil map.
func elems[Obj any, Key comparable](arr []Obj, f func(o Obj) Key) map[Key]Obj {
	byKey := make(map[Key]Obj, len(arr))
	for i := range arr {
		byKey[f(arr[i])] = arr[i]
	}
	return byKey
}
func representType(f *parser.Type) string { | |
res := f.Name | |
switch f.Name { | |
case mapType: | |
res += fmt.Sprintf("<%s,%s>", representType(f.KeyType), representType(f.ValueType)) | |
case listType, setType: | |
res += fmt.Sprintf("<%s>", representType(f.ValueType)) | |
} | |
return res | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment