Skip to content

Instantly share code, notes, and snippets.

@hysios
Last active September 20, 2018 12:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hysios/7b5db0815ce094b9f1b194808befdf71 to your computer and use it in GitHub Desktop.
Save hysios/7b5db0815ce094b9f1b194808befdf71 to your computer and use it in GitHub Desktop.
YAML 做定制解析功能
package main
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"log"
"os"
"reflect"
"strconv"
"github.com/ghodss/yaml"
"github.com/kr/pretty"
)
type (
	// Config is the top-level configuration document. encoding/json matches
	// the lower-case YAML/JSON keys ("train", "framework", ...) to these
	// exported field names case-insensitively.
	Config struct {
		Train Train
	}

	// Train holds the "train" section of the configuration.
	Train struct {
		Framework DefaultFramework
	}

	// DefaultFramework accepts the "framework" value in either of two shapes:
	// a plain string (stored in String) or a mapping (stored in Raw).
	//
	// The embedded Any's pointer-receiver UnmarshalJSON is promoted to
	// *DefaultFramework, so encoding/json hands the whole "framework" value to
	// it. The Any must first be bound to String and Raw (see bindFramework);
	// otherwise neither field can be filled.
	DefaultFramework struct {
		Any
		String string
		Raw    Framework
	}

	// Framework is the expanded (mapping) form of a framework entry, e.g.
	// lang: python, name: tensorflow, version: 1.8.0, sense: gpu.
	// NOTE(review): the meaning of Sense is not established here; it appears
	// to name a target device — confirm.
	Framework struct {
		Lang    string
		Name    string
		Version string
		Sense   string
	}
)
// Any binds a single JSON value to several candidate Go targets so that,
// when decoded, the value lands in whichever target matches its JSON kind
// (for example, a string versus a struct).
type Any struct {
	// contains holds reflect.Values of the targets passed to AnyOf, in order.
	contains []reflect.Value
}

// AnyOf returns an Any bound to objs, preserving their order. Targets are
// written through reflection, which requires addressable values, so each obj
// should be a pointer to the field it is meant to fill (e.g. &s.Field).
func AnyOf(objs ...interface{}) Any {
	var targets []reflect.Value
	for i := range objs {
		targets = append(targets, reflect.ValueOf(objs[i]))
	}
	return Any{contains: targets}
}
// bindFramework points boundTo's embedded Any at its String and Raw fields
// (in that order), so a later JSON decode of boundTo can store a JSON string
// in String or a JSON object in Raw. Call it before unmarshaling into boundTo.
func bindFramework(boundTo *DefaultFramework) {
	str := &boundTo.String
	raw := &boundTo.Raw
	boundTo.Any = AnyOf(str, raw)
}
// UnmarshalJSON implements json.Unmarshaler. It inspects the kind of the JSON
// value in b and stores it in the first bound target of a matching type:
//
//   - a JSON string goes to a bound *string,
//   - a JSON integer goes to a bound *int,
//   - a JSON object goes to a bound pointer-to-struct, replacing its previous
//     contents.
//
// A JSON null leaves every target untouched, following the encoding/json
// convention for Unmarshalers. Any other value (array, boolean), a number that
// is not an int, or a value whose matching target is missing or cannot be set
// is reported as an error.
//
// Because Any is embedded in wrapper structs (e.g. DefaultFramework), this
// method is promoted to the wrapper, so encoding/json calls it to decode the
// wrapper as a whole.
func (a *Any) UnmarshalJSON(b []byte) error {
	switch {
	case string(b) == "null":
		return nil

	case scanString(b):
		v, ok := a.findType(reflect.TypeOf(""))
		if !ok {
			return errors.New("don't have string Type")
		}
		if !v.CanSet() {
			return errors.New("string target is not settable")
		}
		var s string
		if err := json.Unmarshal(b, &s); err != nil {
			return fmt.Errorf("decoding string: %w", err)
		}
		v.SetString(s)

	case scanInt(b):
		v, ok := a.findType(reflect.TypeOf(0))
		if !ok {
			return errors.New("don't have int Type")
		}
		if !v.CanSet() {
			return errors.New("int target is not settable")
		}
		// scanInt matches any JSON number; Atoi rejects fractions, exponents
		// and out-of-range values instead of silently storing 0.
		i, err := strconv.Atoi(string(b))
		if err != nil {
			return fmt.Errorf("decoding int: %w", err)
		}
		v.SetInt(int64(i))

	case scanObject(b):
		v, ok := a.findStruct()
		if !ok {
			return errors.New("don't have struct Type")
		}
		target := reflect.Indirect(v)
		if !target.CanSet() {
			return errors.New("struct target is not settable")
		}
		// Decode into a fresh value so a failed decode leaves the target
		// unchanged and a successful one replaces it rather than merging.
		fresh := reflect.New(target.Type())
		if err := json.Unmarshal(b, fresh.Interface()); err != nil {
			return fmt.Errorf("decoding %s: %w", target.Type(), err)
		}
		target.Set(fresh.Elem())

	default:
		return fmt.Errorf("unsupported JSON value %s", b)
	}
	return nil
}
// findType returns the first bound target whose pointed-to type is exactly t,
// as a settable reflect.Value ready for SetString, SetInt and similar calls.
// Only non-nil pointers qualify: a non-pointer value is not settable (writing
// to it would panic), and a nil pointer or invalid value has nothing to write
// to. The boolean reports whether a match was found.
func (a *Any) findType(t reflect.Type) (reflect.Value, bool) {
	for _, v := range a.contains {
		if v.Kind() != reflect.Ptr || v.IsNil() {
			continue
		}
		if elem := v.Elem(); elem.Type() == t {
			return elem, true
		}
	}
	return reflect.Value{}, false
}
// findStruct returns the first bound target that is a non-nil pointer to a
// struct. It returns the pointer itself (not the struct), so callers write
// through v.Elem(). Non-pointer struct values are skipped because they are not
// settable and calling Elem on them would panic. The boolean reports whether a
// match was found.
func (a *Any) findStruct() (reflect.Value, bool) {
	for _, v := range a.contains {
		if v.Kind() == reflect.Ptr && !v.IsNil() && v.Elem().Kind() == reflect.Struct {
			return v, true
		}
	}
	return reflect.Value{}, false
}
// yaml2json converts a YAML document to its JSON equivalent via
// yaml.YAMLToJSON. A conversion failure is fatal: the error is logged and
// the process exits through log.Fatalf.
func yaml2json(data []byte) []byte {
	converted, err := yaml.YAMLToJSON(data)
	if err == nil {
		return converted
	}
	log.Fatalf("error: %v", err)
	return nil // unreachable: log.Fatalf exits
}
// initializeStruct fills the reference-typed fields of the struct v (whose
// type is t) so they are ready to use:
//
//   - maps get an empty map,
//   - slices get an empty, non-nil slice,
//   - bidirectional channels get an unbuffered channel,
//   - nested structs are initialized recursively,
//   - pointers are pointed at a freshly allocated, recursively initialized value.
//
// Existing values in those fields are overwritten. v should be addressable
// (e.g. reflect.ValueOf(&s).Elem()). Fields that cannot be set — unexported
// fields, or every field when v is not addressable — are left untouched
// instead of panicking, as are unidirectional channels (reflect.MakeChan
// rejects them). A pointer whose element type is already being initialized
// further up the recursion (a self-referential type such as a linked-list
// node) is left nil to avoid infinite recursion. Non-struct values are ignored.
func initializeStruct(t reflect.Type, v reflect.Value) {
	initializeFields(t, v, map[reflect.Type]bool{})
}

// initializeFields does the work of initializeStruct; inProgress holds the
// struct types currently being initialized on the recursion stack.
func initializeFields(t reflect.Type, v reflect.Value, inProgress map[reflect.Type]bool) {
	if v.Kind() != reflect.Struct {
		return
	}
	inProgress[t] = true
	defer delete(inProgress, t)

	for i := 0; i < v.NumField(); i++ {
		f := v.Field(i)
		if !f.CanSet() {
			continue
		}
		ft := t.Field(i).Type
		switch ft.Kind() {
		case reflect.Map:
			f.Set(reflect.MakeMap(ft))
		case reflect.Slice:
			f.Set(reflect.MakeSlice(ft, 0, 0))
		case reflect.Chan:
			if ft.ChanDir() == reflect.BothDir {
				f.Set(reflect.MakeChan(ft, 0))
			}
		case reflect.Struct:
			initializeFields(ft, f, inProgress)
		case reflect.Ptr:
			elem := ft.Elem()
			if inProgress[elem] {
				continue
			}
			fv := reflect.New(elem)
			initializeFields(elem, fv.Elem(), inProgress)
			f.Set(fv)
		}
	}
}
// Sample YAML documents covering both shapes that DefaultFramework accepts
// for train.framework. YAML indentation is significant, so the nested keys
// must stay indented (with spaces, never tabs) inside the raw strings.
var (
	// data1 gives train.framework as a bare scalar. 1.8.0 is not a valid YAML
	// number, so it is read as the string "1.8.0" and ends up in
	// Train.Framework.String.
	data1 = []byte(`train:
  framework: 1.8.0
`)
	// data2 gives train.framework as a mapping, which ends up in
	// Train.Framework.Raw.
	data2 = []byte(`train:
  framework:
    lang: python
    name: tensorflow
    version: 1.8.0
    sense: gpu
`)
)
// scanString reports whether b looks like a JSON string token: at least two
// bytes, opening and closing with a double quote. It does not validate the
// contents between the quotes; callers are expected to pass a single,
// already-valid JSON value. Empty input returns false instead of panicking.
func scanString(b []byte) bool {
	return len(b) >= 2 && b[0] == '"' && b[len(b)-1] == '"'
}
// scanObject reports whether b looks like a JSON object token: at least two
// bytes, opening with '{' and closing with '}'. It does not validate the
// contents; callers are expected to pass a single, already-valid JSON value.
// Empty input returns false instead of panicking.
func scanObject(b []byte) bool {
	return len(b) >= 2 && b[0] == '{' && b[len(b)-1] == '}'
}
// scanInt reports whether b starts like a JSON number token, i.e. with a
// minus sign or a decimal digit (JSON numbers never start with '+' or '.').
// Despite its name it matches any JSON number, including fractions and
// exponents such as 1.5 or 2e3; callers must check that the value actually
// fits an int. Empty input returns false instead of panicking.
func scanInt(b []byte) bool {
	if len(b) == 0 {
		return false
	}
	c := b[0]
	return c == '-' || (c >= '0' && c <= '9')
}
// scanBool reports whether b is exactly a JSON boolean literal, true or false.
// (The previous version joined the two comparisons with &&, so it could never
// return true.)
func scanBool(b []byte) bool {
	switch string(b) {
	case "true", "false":
		return true
	}
	return false
}
// build returns a new, zero-valued Config whose Train.Framework has already
// been bound with bindFramework, so decoding JSON into it can fill either
// Framework.String or Framework.Raw.
func build() *Config {
	cfg := new(Config)
	bindFramework(&cfg.Train.Framework)
	return cfg
}
// main demonstrates both shapes the "framework" key may take: data1 holds a
// bare scalar, which ends up in Framework.String, and data2 holds a mapping,
// which ends up in Framework.Raw.
func main() {
	cfg := decodeConfig(data1)
	log.Printf("String: %# v\n", pretty.Formatter(cfg.Train.Framework.String))

	cfg = decodeConfig(data2)
	log.Printf("framework: %# v\n", pretty.Formatter(cfg.Train.Framework.Raw))
}

// decodeConfig converts the YAML document doc to JSON, echoes the indented
// JSON (plus a trailing newline) to stdout, and decodes it into a freshly
// built, bound Config. Any failure is fatal.
func decodeConfig(doc []byte) *Config {
	data := yaml2json(doc)

	var out bytes.Buffer
	if err := json.Indent(&out, data, "", " "); err != nil {
		log.Fatalln(err)
	}
	// Terminate the JSON on stdout itself; the builtin println used before
	// wrote this newline to stderr.
	out.WriteByte('\n')
	if _, err := out.WriteTo(os.Stdout); err != nil {
		log.Fatalln(err)
	}

	cfg := build()
	if err := json.Unmarshal(data, cfg); err != nil {
		log.Fatalln(err)
	}
	return cfg
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment