Skip to content

Instantly share code, notes, and snippets.

@stillmatic
Last active June 13, 2023 17:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stillmatic/30228bc13460dc6b3c589e34c6afa502 to your computer and use it in GitHub Desktop.
Save stillmatic/30228bc13460dc6b3c589e34c6afa502 to your computer and use it in GitHub Desktop.
golang generic struct validator
package gallant
import (
"context"
"encoding/json"
"reflect"
"github.com/go-playground/validator/v10"
"github.com/pkg/errors"
"gopkg.in/yaml.v3"
)
var (
	// v is the single validator instance shared by every parser in this
	// package. go-playground/validator is documented to be used as a
	// singleton (it caches struct metadata), so Parse calls reuse it instead
	// of constructing a new one each time.
	v = validator.New()
)
// Parser turns raw text (for example, model output) into a value of type T.
// Implementations are expected to be safe for concurrent use; the package's
// own implementations share a single validator, which is concurrency-safe.
type Parser[T any] interface {
	// GetFormatString returns a textual description of the shape of T that
	// Parse expects. It may be empty.
	GetFormatString() string
	// Parse decodes input into a T, returning an error if it cannot.
	Parse(ctx context.Context, input string) (T, error)
}
// JSONParser parses JSON text into values of type T, optionally validating the
// result with the package-level validator.
// It is threadsafe and can be used concurrently: its fields are never modified
// after construction, and the underlying validator is threadsafe as well.
type JSONParser[T any] struct {
	validate bool
	format   string
}

// NewJSONParser returns a new JSONParser.
//
// If validate is true, Parse checks each decoded value against its `validate`
// struct tags. Validation needs T to be a struct or a pointer to a struct. For
// any other T, Parse returns the validator's InvalidValidationError. A struct
// with no `validate` tags simply passes.
//
// The parser also precomputes a Go-like description of T, available through
// GetFormatString, and does so whatever the value of validate. The description
// lists T's exported fields with their types and their `json`/`validate` tags.
// It is empty when T is neither a struct nor a pointer to a struct.
func NewJSONParser[T any](validate bool) *JSONParser[T] {
	// Reflect on *T and take Elem. For an interface type argument the zero T is
	// a nil interface, and reflect.TypeOf on it would return a nil Type.
	tt := reflect.TypeOf((*T)(nil)).Elem()
	// For a pointer-to-struct T, describe the pointed-to struct.
	if tt.Kind() == reflect.Pointer {
		tt = tt.Elem()
	}
	format := ""
	// Only structs have fields to describe. Other kinds (slices, maps, ...) get
	// an empty description instead of the panic NumField would raise.
	if tt.Kind() == reflect.Struct {
		format = "type " + tt.Name() + " struct {\n"
		for i := 0; i < tt.NumField(); i++ {
			field := tt.Field(i)
			if !field.IsExported() {
				// encoding/json cannot set unexported fields, so don't advertise
				// them. Fields promoted from unexported embedded structs are not
				// expanded here.
				continue
			}
			tags := ""
			if tag := field.Tag.Get("json"); tag != "" {
				tags = `json:"` + tag + `"`
			}
			if tag := field.Tag.Get("validate"); tag != "" {
				if tags != "" {
					tags += " "
				}
				tags += `validate:"` + tag + `"`
			}
			line := "\t" + field.Name + " " + field.Type.String()
			if tags != "" {
				line += " `" + tags + "`"
			}
			format += line + "\n"
		}
		format += "}"
	}
	return &JSONParser[T]{
		validate: validate,
		format:   format,
	}
}

// GetFormatString returns the description of T computed by NewJSONParser, or
// "" when T is not a struct or a pointer to a struct.
func (p *JSONParser[T]) GetFormatString() string {
	return p.format
}
// Parse decodes input as JSON into a new T. If the parser was built with
// validate=true, it then checks the decoded value with the package-level
// validator.
//
// On a decoding error Parse returns the zero T, because a partially decoded
// value is not meaningful. On a validation error it returns the fully decoded
// value with the error, so callers can inspect what was rejected. Both errors
// are wrapped with context. ctx is accepted to satisfy Parser but is not used.
func (p *JSONParser[T]) Parse(ctx context.Context, input string) (T, error) {
	var t T
	if err := json.Unmarshal([]byte(input), &t); err != nil {
		var zero T
		return zero, errors.Wrap(err, "failed to unmarshal json")
	}
	if p.validate {
		if err := v.Struct(t); err != nil {
			return t, errors.Wrap(err, "failed to validate struct")
		}
	}
	return t, nil
}
// YAMLParser parses YAML text into values of type T, optionally validating the
// result with the package-level validator.
// It is threadsafe and can be used concurrently: its fields are never modified
// after construction, and the underlying validator is threadsafe as well.
type YAMLParser[T any] struct {
	validate bool
	format   string
}

// NewYamlParser returns a new YAMLParser.
//
// If validate is true, Parse checks each decoded value against its `validate`
// struct tags. Validation needs T to be a struct or a pointer to a struct. For
// any other T, Parse returns the validator's InvalidValidationError. A struct
// with no `validate` tags simply passes.
//
// The parser also precomputes a Go-like description of T, available through
// GetFormatString, and does so whatever the value of validate. The description
// lists T's exported fields with their types and their `yaml`/`validate` tags.
// It is empty when T is neither a struct nor a pointer to a struct.
func NewYamlParser[T any](validate bool) *YAMLParser[T] {
	// Reflect on *T and take Elem. For an interface type argument the zero T is
	// a nil interface, and reflect.TypeOf on it would return a nil Type.
	tt := reflect.TypeOf((*T)(nil)).Elem()
	// For a pointer-to-struct T, describe the pointed-to struct.
	if tt.Kind() == reflect.Pointer {
		tt = tt.Elem()
	}
	format := ""
	// Only structs have fields to describe. Other kinds (e.g. a []string list
	// type) get an empty description instead of the panic NumField would raise.
	if tt.Kind() == reflect.Struct {
		format = "type " + tt.Name() + " struct {\n"
		for i := 0; i < tt.NumField(); i++ {
			field := tt.Field(i)
			if !field.IsExported() {
				// The YAML decoder does not populate unexported fields, so don't
				// advertise them. Fields promoted from unexported embedded
				// structs are not expanded here.
				continue
			}
			tags := ""
			if tag := field.Tag.Get("yaml"); tag != "" {
				tags = `yaml:"` + tag + `"`
			}
			if tag := field.Tag.Get("validate"); tag != "" {
				if tags != "" {
					tags += " "
				}
				tags += `validate:"` + tag + `"`
			}
			line := "\t" + field.Name + " " + field.Type.String()
			if tags != "" {
				line += " `" + tags + "`"
			}
			format += line + "\n"
		}
		format += "}"
	}
	return &YAMLParser[T]{
		validate: validate,
		format:   format,
	}
}

// GetFormatString returns the description of T computed by NewYamlParser, or
// "" when T is not a struct or a pointer to a struct.
func (p *YAMLParser[T]) GetFormatString() string {
	return p.format
}
// Parse decodes input as YAML into a new T. If the parser was built with
// validate=true, it then checks the decoded value with the package-level
// validator.
//
// On a decoding error Parse returns the zero T, because a partially decoded
// value is not meaningful. On a validation error it returns the fully decoded
// value with the error, so callers can inspect what was rejected. Both errors
// are wrapped with context. ctx is accepted to satisfy Parser but is not used.
func (p *YAMLParser[T]) Parse(ctx context.Context, input string) (T, error) {
	var t T
	if err := yaml.Unmarshal([]byte(input), &t); err != nil {
		var zero T
		return zero, errors.Wrap(err, "failed to unmarshal yaml")
	}
	if p.validate {
		if err := v.Struct(t); err != nil {
			return t, errors.Wrap(err, "failed to validate struct")
		}
	}
	return t, nil
}
package gallant_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/stillmatic/gallant/gallant"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v3"
)
// Fixtures shared by the parser tests and benchmarks.
type (
	// employee carries validation rules: a non-empty name and an age in [18, 60].
	employee struct {
		Name string `json:"name" yaml:"name" validate:"required"`
		Age  int    `json:"age" yaml:"age" validate:"gte=18,lte=60"`
	}

	// company nests employees; its own fields carry no validation rules.
	company struct {
		Name      string      `json:"name" yaml:"name"`
		Employees []*employee `json:"employees" yaml:"employees"`
	}

	// bulletList exercises decoding into a non-struct type.
	bulletList []string
)

var (
	// testCo is a company whose employees all satisfy the employee rules.
	testCo = company{
		Name: "Acme",
		Employees: []*employee{
			{Name: "John", Age: 30},
			{Name: "Jane", Age: 25},
		},
	}

	// badEmployees each break one employee rule: the first is below the
	// minimum age, the second has an empty name.
	badEmployees = []employee{
		{Name: "John", Age: 0},
		{Name: "", Age: 25},
	}
)
// TestParsers checks four things:
//   - testCo round-trips through both JSONParser and YAMLParser.
//   - YAMLParser decodes into a non-struct type.
//   - Struct-tag validation rejects every entry of badEmployees. Each entry is
//     marshalled on its own, so the error must come from validation, not decoding.
//   - A type mismatch is reported as a decoding error.
func TestParsers(t *testing.T) {
	ctx := context.Background()

	t.Run("JSONParser", func(t *testing.T) {
		jsonParser := gallant.NewJSONParser[company](false)
		input, err := json.Marshal(testCo)
		assert.NoError(t, err)
		actual, err := jsonParser.Parse(ctx, string(input))
		assert.NoError(t, err)
		assert.Equal(t, testCo, actual)

		// test failure: each bad employee is valid JSON for an employee, so the
		// only way Parse can fail is validation.
		employeeParser := gallant.NewJSONParser[employee](true)
		for i, bad := range badEmployees {
			t.Run(fmt.Sprintf("bad-%d", i), func(t *testing.T) {
				input, err := json.Marshal(bad)
				assert.NoError(t, err)
				_, err = employeeParser.Parse(ctx, string(input))
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), "failed to validate struct")
				}
			})
		}
	})

	t.Run("YAMLParser", func(t *testing.T) {
		// test struct
		yamlParser := gallant.NewYamlParser[company](true)
		input, err := yaml.Marshal(testCo)
		assert.NoError(t, err)
		actual, err := yamlParser.Parse(ctx, string(input))
		assert.NoError(t, err)
		assert.Equal(t, testCo, actual)

		// Test bullet list. Comparing whole values also catches a result that
		// is shorter or longer than expected.
		yamlParserBullet := gallant.NewYamlParser[bulletList](false)
		input2 := `- bullet point 1
- bullet point 2
- bullet point 3`
		expected := bulletList{"bullet point 1", "bullet point 2", "bullet point 3"}
		actual2, err := yamlParserBullet.Parse(ctx, input2)
		assert.NoError(t, err)
		assert.Equal(t, expected, actual2)
	})

	t.Run("fail validation", func(t *testing.T) {
		yamlParser := gallant.NewYamlParser[employee](true)
		for i, bad := range badEmployees {
			t.Run(fmt.Sprintf("bad-%d", i), func(t *testing.T) {
				input, err := yaml.Marshal(bad)
				assert.NoError(t, err)
				_, err = yamlParser.Parse(ctx, string(input))
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), "failed to validate struct")
				}
			})
		}
	})

	t.Run("fail decode", func(t *testing.T) {
		// A YAML sequence cannot be decoded into a single employee struct.
		yamlParser := gallant.NewYamlParser[employee](false)
		input, err := yaml.Marshal(badEmployees)
		assert.NoError(t, err)
		_, err = yamlParser.Parse(ctx, string(input))
		assert.Error(t, err)
	})
}
// location is the extraction target for the commented-out example in
// TestMarvin.
//
// Latitude and Longitude are range-checked rather than marked `required`. For
// numeric fields, `required` rejects the zero value, which would wrongly reject
// coordinates on the equator or the prime meridian.
type location struct {
	City      string  `json:"city" yaml:"city" validate:"required"`
	State     string  `json:"state" yaml:"state" validate:"required"`
	Country   string  `json:"country" yaml:"country" validate:"required"`
	Latitude  float64 `json:"latitude" yaml:"latitude" validate:"gte=-90,lte=90"`
	Longitude float64 `json:"longitude" yaml:"longitude" validate:"gte=-180,lte=180"`
}
// marvin asks model to extract a value of type T from unstructured input text.
// The prompts are adapted from PrefectHQ Marvin. The system prompt embeds the
// format string of a YAMLParser for T (validation disabled), and the model's
// reply is parsed as YAML into T.
//
// If the first attempt fails, marvin retries exactly once. A failure is an API
// error, a response with no choices, or a reply that does not parse. The retry
// appends the failed reply (when there was one) and a system message that
// describes the error and restates the format and input. The t argument fixes
// the type parameter and is also what marvin returns when both attempts fail.
func marvin[T any](t T, input string, model gallant.LLM) (output T, err error) {
	yamlParser := gallant.NewYamlParser[T](false)
	promptTemplate := `## Objective
Your job is to generate a YAML representation for a Golang struct with the following tag definitions:
%s
Given unstructured text infer when possible the missing data. Do not give any additional detail, instructions, or even
punctuation; respond ONLY with a return value that can be parsed into the
expected YAML form.`
	inputTemplate := `The function was called with the following input:
%s
# Instructions
Given unstructured text infer when possible the missing data. Generate the function's output. Do not explain the type signature or give
guidance on parsing.
`
	inputMessage := fmt.Sprintf(inputTemplate, input)
	prompt := fmt.Sprintf(promptTemplate, yamlParser.GetFormatString())
	messages := []openai.ChatCompletionMessage{
		{
			Role:    openai.ChatMessageRoleSystem,
			Content: prompt,
		},
		{
			Role:    openai.ChatMessageRoleUser,
			Content: inputMessage,
		},
	}
	// callAndParse sends one chat request. It returns the raw reply (empty if
	// no reply was obtained) and the parsed value.
	callAndParse := func(messages []openai.ChatCompletionMessage) (respStr string, outStruct T, err error) {
		req := openai.ChatCompletionRequest{
			Model:    openai.GPT3Dot5Turbo,
			Messages: messages,
			// NOTE(review): if go-openai declares Temperature with `omitempty`,
			// 0.0 is dropped from the request body and the API default applies
			// instead. Confirm against the pinned go-openai version.
			Temperature: 0.0,
			MaxTokens:   256,
		}
		ctx := context.Background()
		resp, err := model.CreateChatCompletion(ctx, req)
		if err != nil {
			return "", t, err
		}
		// Guard the index below: an empty Choices slice would otherwise panic.
		if len(resp.Choices) == 0 {
			return "", t, fmt.Errorf("chat completion returned no choices")
		}
		respStr = resp.Choices[0].Message.Content
		outStruct, err = yamlParser.Parse(ctx, respStr)
		return respStr, outStruct, err
	}
	respStr, outStruct, err := callAndParse(messages)
	if err != nil {
		// retry this once
		if respStr != "" {
			// Echo the model's unusable reply back so it can correct itself.
			// After an API error there is no reply, so nothing is echoed.
			messages = append(messages, openai.ChatCompletionMessage{
				Role:    openai.ChatMessageRoleAssistant,
				Content: respStr,
			})
		}
		retryMsg := `An error occurred: %s
Remember you must follow the instructions exactly.
Given unstructured text infer the struct. Use YAML to return the response, the format is:
%s
The input is:
%s
`
		messages = append(messages, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleSystem,
			Content: fmt.Sprintf(retryMsg, err.Error(), yamlParser.GetFormatString(), input),
		})
		_, outStruct, err = callAndParse(messages)
		if err != nil {
			return t, err
		}
	}
	return outStruct, nil
}
// TestMarvin is a manual, live test of marvin against the OpenAI API. Because
// it needs network access and an API key, it is skipped rather than silently
// passing. To run it, remove the Skip, uncomment the body and set
// OPENAI_API_KEY.
func TestMarvin(t *testing.T) {
	t.Skip("live OpenAI test: remove this Skip, uncomment the body and set OPENAI_API_KEY to run")
	// adapted from:
	// https://github.com/PrefectHQ/marvin/blob/e6445205fe37df37bfc9375f67cdef8ef111ef16/src/marvin/bot/base.py#L45
	// https://github.com/PrefectHQ/marvin/blob/e6445205fe37df37bfc9375f67cdef8ef111ef16/src/marvin/ai_models/base.py#L40
	// oaiClient := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
	// model := gallant.LLM(oaiClient)
	// var loc location
	// resp, err := marvin(loc, "No way, I'm also from the windy city!", model)
	// assert.NoError(t, err)
	// assert.Equal(t, "Chicago", resp.City)
	// assert.Equal(t, "Illinois", resp.State)
	// assert.Equal(t, "United States", resp.Country)
}
func BenchmarkParser(b *testing.B) {
b.Run("JSONParser-NoValidate", func(b *testing.B) {
jsonParser := gallant.NewJSONParser[company](false)
input, err := json.Marshal(testCo)
assert.NoError(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
actual, err := jsonParser.Parse(context.Background(), string(input))
assert.NoError(b, err)
assert.Equal(b, testCo, actual)
}
})
b.Run("JSONParser-Validate", func(b *testing.B) {
jsonParser := gallant.NewJSONParser[company](true)
input, err := json.Marshal(testCo)
assert.NoError(b, err)
b.ResetTimer()
for i := 0; i < b.N; i++ {
actual, err := jsonParser.Parse(context.Background(), string(input))
assert.NoError(b, err)
assert.Equal(b, testCo, actual)
}
})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment