Last active
June 13, 2023 17:32
-
-
Save stillmatic/30228bc13460dc6b3c589e34c6afa502 to your computer and use it in GitHub Desktop.
golang generic struct validator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gallant | |
import ( | |
"context" | |
"encoding/json" | |
"reflect" | |
"github.com/go-playground/validator/v10" | |
"github.com/pkg/errors" | |
"gopkg.in/yaml.v3" | |
) | |
var v = validator.New() | |
// Parser describes anything that can turn a raw string into a value of type T.
//
// Implementations are meant to be threadsafe and usable concurrently; the
// validator shared by the parsers in this package is threadsafe as well.
type Parser[T any] interface {
	// Parse decodes input into a T, returning an error when decoding (or any
	// validation the implementation performs) fails.
	Parse(ctx context.Context, input string) (T, error)
	// GetFormatString returns a textual description of the layout expected
	// for T. It may be empty.
	GetFormatString() string
}
// JSONParser decodes JSON text into values of type T and optionally validates
// the result.
//
// Both fields are set once by NewJSONParser and only read afterwards, so a
// JSONParser can be shared between goroutines. The underlying validator is
// threadsafe as well.
type JSONParser[T any] struct {
	// validate reports whether Parse runs the struct validator after decoding.
	validate bool
	// format is the description of T returned by GetFormatString.
	format string
}
// NewJSONParser returns a new JSONParser | |
// If validate is true, the parser will validate the struct using the struct tags | |
// The underlying struct must have the `validate` tag if validation is enabled, otherwise it will return `InvalidValidationError` | |
func NewJSONParser[T any](validate bool) *JSONParser[T] { | |
var t T | |
format := "" | |
if validate { | |
format += "type " | |
tt := reflect.TypeOf(t) | |
format += tt.Name() + " struct {\n" | |
for i := 0; i < tt.NumField(); i++ { | |
field := tt.Field(i) | |
format += field.Name + "`json:" + field.Tag.Get("json") | |
if field.Tag.Get("validate") != "" { | |
format += " validate:" + field.Tag.Get("validate") + "`" | |
} | |
format += "\n" | |
} | |
format += "}" | |
} | |
return &JSONParser[T]{ | |
validate: validate, | |
format: format, | |
} | |
} | |
func (p *JSONParser[T]) GetFormatString() string { | |
return p.format | |
} | |
func (p *JSONParser[T]) Parse(ctx context.Context, input string) (T, error) { | |
var t T | |
err := json.Unmarshal([]byte(input), &t) | |
if err != nil { | |
return t, err | |
} | |
if p.validate { | |
err = v.Struct(t) | |
if err != nil { | |
return t, errors.Wrap(err, "failed to validate struct") | |
} | |
} | |
return t, nil | |
} | |
// YAMLParser decodes YAML text into values of type T and optionally validates
// the result.
//
// Both fields are set once by NewYamlParser and only read afterwards, so a
// YAMLParser can be shared between goroutines. The underlying validator is
// threadsafe as well.
type YAMLParser[T any] struct {
	// validate reports whether Parse runs the struct validator after decoding.
	validate bool
	// format is the description of T returned by GetFormatString.
	format string
}
// NewYamlParser returns a new YAMLParser | |
// If validate is true, the parser will validate the struct using the struct tags | |
// The underlying struct must have the `validate` tag if validation is enabled, otherwise it will return `InvalidValidationError` | |
func NewYamlParser[T any](validate bool) *YAMLParser[T] { | |
var t T | |
format := "" | |
if validate { | |
format += "type " | |
tt := reflect.TypeOf(t) | |
format += tt.Name() + " struct {\n" | |
for i := 0; i < tt.NumField(); i++ { | |
field := tt.Field(i) | |
format += field.Name + "`yaml:" + field.Tag.Get("yaml") | |
if field.Tag.Get("validate") != "" { | |
format += " validate:" + field.Tag.Get("validate") + "`" | |
} | |
format += "\n" | |
} | |
format += "}" | |
} | |
return &YAMLParser[T]{ | |
validate: validate, | |
format: format, | |
} | |
} | |
func (p *YAMLParser[T]) GetFormatString() string { | |
return p.format | |
} | |
func (p *YAMLParser[T]) Parse(ctx context.Context, input string) (T, error) { | |
var t T | |
err := yaml.Unmarshal([]byte(input), &t) | |
if err != nil { | |
return t, err | |
} | |
if p.validate { | |
err = v.Struct(t) | |
if err != nil { | |
return t, errors.Wrap(err, "failed to validate struct") | |
} | |
} | |
return t, nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gallant_test | |
import ( | |
"context" | |
"encoding/json" | |
"fmt" | |
"testing" | |
"github.com/sashabaranov/go-openai" | |
"github.com/stillmatic/gallant/gallant" | |
"github.com/stretchr/testify/assert" | |
"gopkg.in/yaml.v3" | |
) | |
// Fixture types used by the parser tests.
type (
	// employee carries validation rules: a name is mandatory and the age must
	// lie between 18 and 60 inclusive.
	employee struct {
		Name string `json:"name" yaml:"name" validate:"required"`
		Age  int    `json:"age" yaml:"age" validate:"gte=18,lte=60"`
	}

	// company groups employees; it declares no validation rules of its own.
	company struct {
		Name      string      `json:"name" yaml:"name"`
		Employees []*employee `json:"employees" yaml:"employees"`
	}

	// bulletList is a plain list of strings, used to exercise non-struct
	// parsing.
	bulletList []string
)
var testCo = company{ | |
Name: "Acme", | |
Employees: []*employee{ | |
{ | |
Name: "John", | |
Age: 30, | |
}, | |
{ | |
Name: "Jane", | |
Age: 25, | |
}, | |
}, | |
} | |
var badEmployees = []employee{ | |
{ | |
Name: "John", | |
Age: 0, | |
}, | |
{ | |
Name: "", | |
Age: 25, | |
}, | |
} | |
func TestParsers(t *testing.T) { | |
t.Run("JSONParser", func(t *testing.T) { | |
jsonParser := gallant.NewJSONParser[company](false) | |
input, err := json.Marshal(testCo) | |
assert.NoError(t, err) | |
actual, err := jsonParser.Parse(context.Background(), string(input)) | |
assert.NoError(t, err) | |
assert.Equal(t, testCo, actual) | |
// test failure | |
employeeParser := gallant.NewJSONParser[employee](true) | |
input2, err := json.Marshal(badEmployees) | |
assert.NoError(t, err) | |
_, err = employeeParser.Parse(context.Background(), string(input2)) | |
assert.Error(t, err) | |
}) | |
t.Run("YAMLParser", func(t *testing.T) { | |
// test struct | |
yamlParser := gallant.NewYamlParser[company](true) | |
input, err := yaml.Marshal(testCo) | |
assert.NoError(t, err) | |
actual, err := yamlParser.Parse(context.Background(), string(input)) | |
assert.NoError(t, err) | |
assert.Equal(t, testCo, actual) | |
// Test bullet list | |
yamlParserBullet := gallant.NewYamlParser[bulletList](false) | |
input2 := `- bullet point 1 | |
- bullet point 2 | |
- bullet point 3` | |
expected := []string{"bullet point 1", "bullet point 2", "bullet point 3"} | |
actual2, err := yamlParserBullet.Parse(context.Background(), input2) | |
assert.NoError(t, err) | |
for i, v := range actual2 { | |
assert.Equal(t, expected[i], v) | |
} | |
}) | |
t.Run("fail validation", func(t *testing.T) { | |
yamlParser := gallant.NewYamlParser[employee](false) | |
input, err := yaml.Marshal(badEmployees) | |
assert.NoError(t, err) | |
resp, err := yamlParser.Parse(context.Background(), string(input)) | |
assert.Error(t, err) | |
_ = resp | |
}) | |
} | |
// location is the target type for the marvin extraction example. Every field
// is tagged `validate:"required"`.
// NOTE(review): validator's `required` rule rejects zero values, so a
// latitude or longitude of exactly 0 would presumably fail — confirm this is
// intended.
type location struct {
	// Place names.
	City    string `json:"city" yaml:"city" validate:"required"`
	State   string `json:"state" yaml:"state" validate:"required"`
	Country string `json:"country" yaml:"country" validate:"required"`

	// Coordinates.
	Latitude  float64 `json:"latitude" yaml:"latitude" validate:"required"`
	Longitude float64 `json:"longitude" yaml:"longitude" validate:"required"`
}
func marvin[T any](t T, input string, model gallant.LLM) (output T, err error) { | |
yamlParser := gallant.NewYamlParser[T](false) | |
promptTemplate := `## Objective | |
Your job is to generate a YAML representation for a Golang struct with the following tag definitions: | |
%s | |
Given unstructured text infer when possible the missing data. Do not give any additional detail, instructions, or even | |
punctuation; respond ONLY with a return value that can be parsed into the | |
expected YAML form.` | |
inputTemplate := `The function was called with the following input: | |
%s | |
# Instructions | |
Given unstructured text infer when possible the missing data. Generate the function's output. Do not explain the type signature or give | |
guidance on parsing. | |
` | |
inputMessage := fmt.Sprintf(inputTemplate, input) | |
prompt := fmt.Sprintf(promptTemplate, yamlParser.GetFormatString()) | |
messages := []openai.ChatCompletionMessage{ | |
{ | |
Role: openai.ChatMessageRoleSystem, | |
Content: prompt, | |
}, | |
{ | |
Role: openai.ChatMessageRoleUser, | |
Content: inputMessage, | |
}, | |
} | |
callAndParse := func(messages []openai.ChatCompletionMessage) (respStr string, outStruct T, err error) { | |
req := openai.ChatCompletionRequest{ | |
Model: openai.GPT3Dot5Turbo, | |
Messages: messages, | |
Temperature: 0.0, | |
MaxTokens: 256, | |
} | |
ctx := context.Background() | |
resp, err := model.CreateChatCompletion(ctx, req) | |
if err != nil { | |
return "", t, err | |
} | |
respStr = resp.Choices[0].Message.Content | |
outStruct, err = yamlParser.Parse(ctx, respStr) | |
return respStr, outStruct, err | |
} | |
respStr, outStruct, err := callAndParse(messages) | |
if err != nil { | |
// retry this once | |
messages = append(messages, openai.ChatCompletionMessage{ | |
Role: openai.ChatMessageRoleAssistant, | |
Content: respStr, | |
}) | |
retryMsg := `An error occurred: %s | |
Remember you must follow the instructions exactly. | |
Given unstructured text infer the struct. Use YAML to return the response, the format is: | |
%s | |
The input is: | |
%s | |
` | |
messages = append(messages, openai.ChatCompletionMessage{ | |
Role: openai.ChatMessageRoleSystem, | |
Content: fmt.Sprintf(retryMsg, err.Error(), yamlParser.GetFormatString(), input), | |
}) | |
_, outStruct, err = callAndParse(messages) | |
if err != nil { | |
return t, err | |
} | |
} | |
return outStruct, nil | |
} | |
// TestMarvin is a placeholder for an end-to-end check of marvin against the
// live OpenAI API. The real body is kept commented out below; the test is
// skipped explicitly so the run reports SKIP instead of a PASS that asserted
// nothing.
func TestMarvin(t *testing.T) {
	// adapted from:
	// https://github.com/PrefectHQ/marvin/blob/e6445205fe37df37bfc9375f67cdef8ef111ef16/src/marvin/bot/base.py#L45
	// https://github.com/PrefectHQ/marvin/blob/e6445205fe37df37bfc9375f67cdef8ef111ef16/src/marvin/ai_models/base.py#L40
	t.Skip("disabled: requires an OpenAI API key and network access")

	// oaiClient := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	// model := gallant.LLM(oaiClient)
	// var location location
	// resp, err := marvin(location, "No way, I'm also from the windy city!", model)
	// assert.NoError(t, err)
	// assert.Equal(t, "Chicago", resp.City)
	// assert.Equal(t, "Illinois", resp.State)
	// assert.Equal(t, "United States", resp.Country)
}
func BenchmarkParser(b *testing.B) { | |
b.Run("JSONParser-NoValidate", func(b *testing.B) { | |
jsonParser := gallant.NewJSONParser[company](false) | |
input, err := json.Marshal(testCo) | |
assert.NoError(b, err) | |
b.ResetTimer() | |
for i := 0; i < b.N; i++ { | |
actual, err := jsonParser.Parse(context.Background(), string(input)) | |
assert.NoError(b, err) | |
assert.Equal(b, testCo, actual) | |
} | |
}) | |
b.Run("JSONParser-Validate", func(b *testing.B) { | |
jsonParser := gallant.NewJSONParser[company](true) | |
input, err := json.Marshal(testCo) | |
assert.NoError(b, err) | |
b.ResetTimer() | |
for i := 0; i < b.N; i++ { | |
actual, err := jsonParser.Parse(context.Background(), string(input)) | |
assert.NoError(b, err) | |
assert.Equal(b, testCo, actual) | |
} | |
}) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment