Skip to content

Instantly share code, notes, and snippets.

@zhnxin
Last active March 12, 2020 04:29
Show Gist options
  • Save zhnxin/8aef8c1bd45382e7e4211409d9dc9431 to your computer and use it in GitHub Desktop.
Save zhnxin/8aef8c1bd45382e7e4211409d9dc9431 to your computer and use it in GitHub Desktop.
package foreignkey
import (
"errors"
"fmt"
"reflect"
"regexp"
"strings"
)
// Reference options usable for ON UPDATE / ON DELETE. The numeric values
// double as indices into mysqlRreferenceOption (the SQL spelling of each
// option), so they must stay contiguous and in this order.
const (
	MysqlRreferenceOption_RESTRICT    MysqlRreferenceOption = 0
	MysqlRreferenceOption_CASCADE     MysqlRreferenceOption = 1
	MysqlRreferenceOption_SET_NULL    MysqlRreferenceOption = 2
	MysqlRreferenceOption_NO_ACTION   MysqlRreferenceOption = 3
	MysqlRreferenceOption_SET_DEFAULT MysqlRreferenceOption = 4
)
var (
	// mysqlForeignkeyPattern captures the backtick-quoted constraint name in
	// a "CONSTRAINT `name` FOREIGN KEY" clause.
	mysqlForeignkeyPattern = regexp.MustCompile("CONSTRAINT `(.*?)` FOREIGN KEY")

	// mysqlRreferenceOptionMap maps the tag spelling of a reference option
	// (multi-word options joined with "_") to its enum value.
	mysqlRreferenceOptionMap = map[string]MysqlRreferenceOption{
		"RESTRICT":    MysqlRreferenceOption_RESTRICT,
		"CASCADE":     MysqlRreferenceOption_CASCADE,
		"SET_NULL":    MysqlRreferenceOption_SET_NULL,
		"NO_ACTION":   MysqlRreferenceOption_NO_ACTION,
		"SET_DEFAULT": MysqlRreferenceOption_SET_DEFAULT,
	}

	// mysqlRreferenceOption is the SQL text of each reference option,
	// indexed by its MysqlRreferenceOption value.
	mysqlRreferenceOption = []string{"RESTRICT", "CASCADE", "SET NULL", "NO ACTION", "SET DEFAULT"}

	// tagHandlers dispatches an upper-cased foreign_key tag name to the
	// method that consumes that tag.
	tagHandlers = map[string]tagHandler{
		"UPDATE":     (*MysqlForeignKeyOption).HandleUpdateRef,
		"DELETE":     (*MysqlForeignKeyOption).HandleDeleteRef,
		"REFERENCES": (*MysqlForeignKeyOption).HandleReferences,
		"KEY":        (*MysqlForeignKeyOption).HandleKey,
	}
)
// tableStructWithTableName is implemented by the table structs passed to
// GetFroeignkeyConfig; TableName supplies the table that owns the constraints.
type tableStructWithTableName interface {
	TableName() string
}

// tagHandler applies one recognised foreign_key tag token to the option
// being built, returning an error if the token's arguments are invalid.
type tagHandler func(*MysqlForeignKeyOption) error

// MysqlRreferenceOption identifies a referential action used in ON UPDATE /
// ON DELETE clauses.
type MysqlRreferenceOption uint8

// MysqlForeignKeyOption describes one foreign key constraint parsed from a
// struct field's foreign_key tag.
type MysqlForeignKeyOption struct {
	Table        string                // table that owns the constraint
	Name         string                // constraint name
	Key          string                // referencing column
	ReferedTable string                // referenced table
	ReferedKey   string                // referenced column
	OnUpdate     MysqlRreferenceOption // ON UPDATE action
	OnDelete     MysqlRreferenceOption // ON DELETE action

	// Scratch state used while a foreign_key tag is being parsed.
	tagName    string   // upper-cased name of the token being handled
	params     []string // comma-separated arguments inside the token's parentheses
	preTag     string   // previous token, upper-cased
	nextTag    string   // following token, as written
	ignoreNext bool     // set by a handler that consumed nextTag as its argument
}
// HandleReferences consumes a REFERENCES(table,key) tag. Exactly two
// comma-separated arguments are required; all single quotes are stripped from
// each, and the results are stored in ReferedTable and ReferedKey (even when
// one of them turns out empty). It returns an error if the argument count is
// not two or if either stripped argument is empty.
func (opt *MysqlForeignKeyOption) HandleReferences() error {
	if n := len(opt.params); n != 2 {
		return fmt.Errorf("REFERENCES(table,key) is needed")
	}
	unquote := func(s string) string { return strings.ReplaceAll(s, "'", "") }
	table, key := unquote(opt.params[0]), unquote(opt.params[1])
	opt.ReferedTable, opt.ReferedKey = table, key
	if table == "" || key == "" {
		return fmt.Errorf("REFERENCES(table,key) not be NULL")
	}
	return nil
}
// handleRefOption returns the reference_option argument of the UPDATE/DELETE
// tag currently being handled. The result is normalized to the spelling used
// as keys of mysqlRreferenceOptionMap: single quotes removed, upper-cased, and
// whitespace-separated words joined by "_". For example, `cascade` and
// 'SET NULL' become CASCADE and SET_NULL.
//
// The argument is the first parenthesized parameter, e.g. DELETE(CASCADE).
// Without parameters it is the following tag token, e.g. DELETE CASCADE. In
// that case ignoreNext is set so the caller skips that token. An empty result
// means the argument was missing or blank.
func (opt *MysqlForeignKeyOption) handleRefOption() string {
	var p string
	if len(opt.params) < 1 {
		p = opt.nextTag
		opt.ignoreNext = true
	} else {
		p = opt.params[0]
	}
	p = strings.ReplaceAll(p, "'", "")
	// Tag names are matched case-insensitively, so reference options are too.
	// Joining words with "_" also accepts the SQL spelling of the multi-word
	// options ("SET NULL", "NO ACTION", "SET DEFAULT").
	return strings.Join(strings.Fields(strings.ToUpper(p)), "_")
}
// HandleDeleteRef consumes a DELETE tag and stores its reference option in
// OnDelete. The option text comes from handleRefOption and is looked up in
// mysqlRreferenceOptionMap. If the text is not a known key, OnDelete is set
// to the map's zero value and an error naming the text is returned.
func (opt *MysqlForeignKeyOption) HandleDeleteRef() error {
	p := opt.handleRefOption()
	ref, found := mysqlRreferenceOptionMap[p]
	opt.OnDelete = ref
	if found {
		return nil
	}
	return fmt.Errorf("unknow reference_option:%v", p)
}
// HandleUpdateRef consumes an UPDATE tag and stores its reference option in
// OnUpdate. The option text comes from handleRefOption and is looked up in
// mysqlRreferenceOptionMap. If the text is not a known key, OnUpdate is set
// to the map's zero value and an error naming the text is returned.
func (opt *MysqlForeignKeyOption) HandleUpdateRef() error {
	p := opt.handleRefOption()
	ref, found := mysqlRreferenceOptionMap[p]
	opt.OnUpdate = ref
	if found {
		return nil
	}
	return fmt.Errorf("unknow reference_option:%v", p)
}
// HandleKey consumes a KEY(column) tag and overrides Key with the first
// argument, with all single quotes removed. It returns an error, leaving Key
// unchanged, when the tag has no parenthesized arguments.
func (opt *MysqlForeignKeyOption) HandleKey() error {
	if len(opt.params) == 0 {
		return fmt.Errorf("key(key_name) not be NULL")
	}
	opt.Key = strings.ReplaceAll(opt.params[0], "'", "")
	return nil
}
// AddConstraintSql renders the ALTER TABLE ... ADD CONSTRAINT statement for
// this foreign key. Identifiers are wrapped in backticks but not escaped.
// OnUpdate and OnDelete are rendered by indexing mysqlRreferenceOption, so
// they must hold one of the MysqlRreferenceOption_* values; any other value
// panics with an index-out-of-range error.
func (opt *MysqlForeignKeyOption) AddConstraintSql() string {
	const stmt = "ALTER TABLE `%s` ADD CONSTRAINT `%s` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`) ON UPDATE %s ON DELETE %s;"
	onUpdate := mysqlRreferenceOption[int(opt.OnUpdate)]
	onDelete := mysqlRreferenceOption[int(opt.OnDelete)]
	return fmt.Sprintf(stmt,
		opt.Table, opt.Name, opt.Key,
		opt.ReferedTable, opt.ReferedKey,
		onUpdate, onDelete)
}
// DropConstraintSql renders the ALTER TABLE statement that drops the foreign
// key named Name from Table. Identifiers are wrapped in backticks but not
// escaped.
//
// NOTE(review): "DROP FOREIGN KEY IF EXISTS" is MariaDB syntax. MySQL's ALTER
// TABLE grammar has no IF EXISTS for DROP FOREIGN KEY, so confirm which
// server this targets.
func (opt *MysqlForeignKeyOption) DropConstraintSql() string {
	const stmt = "ALTER TABLE `%s` DROP FOREIGN KEY IF EXISTS `%s`;"
	return fmt.Sprintf(stmt, opt.Table, opt.Name)
}
// GetFroeignkeyConfig collects the foreign keys declared on bean's struct
// fields through the `foreign_key` struct tag. It returns one option per
// tagged field, in field order.
//
// bean may be a struct, or a pointer/slice/array/map/chan wrapping one at most
// two levels deep (e.g. *T or []*T). Every option gets Table from
// bean.TableName(), and Key defaults to the Go field name. A tag whose first
// token is "-" is skipped.
//
// Tag tokens are dispatched to tagHandlers by their upper-cased name (the text
// before "(" when arguments are present). A handler may consume the following
// token as its argument. Any other token becomes the constraint name, with one
// pair of surrounding single quotes removed.
//
// It returns an error if bean is nil or does not resolve to a struct type, if
// a token is malformed, or if a handler rejects its arguments.
func GetFroeignkeyConfig(bean tableStructWithTableName) ([]*MysqlForeignKeyOption, error) {
	t, err := foreignKeyStructType(bean)
	if err != nil {
		return nil, err
	}
	options := []*MysqlForeignKeyOption{}
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		tagStr, ok := field.Tag.Lookup("foreign_key")
		if !ok {
			continue
		}
		tags := nonEmptyForeignKeyTags(splitTag(tagStr))
		if len(tags) == 0 || tags[0] == "-" {
			continue
		}
		opt, err := parseForeignKeyField(bean.TableName(), field.Name, tags)
		if err != nil {
			return nil, err
		}
		options = append(options, opt)
	}
	return options, nil
}

// foreignKeyStructType resolves the struct type described by bean. It unwraps
// at most two levels of pointer/slice/array/map/chan element types. It returns
// an error, instead of panicking, when bean is nil or no struct is reached.
func foreignKeyStructType(bean tableStructWithTableName) (reflect.Type, error) {
	if bean == nil {
		return nil, errors.New("foreign key bean must not be nil")
	}
	t := reflect.TypeOf(bean)
	for depth := 0; depth < 2 && t.Kind() != reflect.Struct; depth++ {
		switch t.Kind() {
		case reflect.Ptr, reflect.Slice, reflect.Array, reflect.Map, reflect.Chan:
			t = t.Elem()
		default:
			return nil, fmt.Errorf("foreign key bean must be a struct, got %v", t)
		}
	}
	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("foreign key bean must be a struct, got %v", t)
	}
	return t, nil
}

// nonEmptyForeignKeyTags drops empty tokens in place. Left in, an empty token
// would overwrite Name or be consumed as a handler's argument.
func nonEmptyForeignKeyTags(tags []string) []string {
	out := tags[:0]
	for _, tag := range tags {
		if tag != "" {
			out = append(out, tag)
		}
	}
	return out
}

// parseForeignKeyField builds the option for one struct field from its
// non-empty tag tokens.
func parseForeignKeyField(table, fieldName string, tags []string) (*MysqlForeignKeyOption, error) {
	opt := &MysqlForeignKeyOption{Table: table, Key: fieldName}
	for j, key := range tags {
		if opt.ignoreNext {
			// The previous handler already used this token as its argument.
			opt.ignoreNext = false
			continue
		}
		opt.tagName = strings.ToUpper(key)
		opt.params = []string{}
		// Index into key itself, not its upper-cased copy: strings.ToUpper can
		// change the byte length of non-ASCII runes, so offsets from the copy
		// are not valid for slicing key.
		if pStart := strings.Index(key, "("); pStart > -1 {
			if pStart == 0 {
				return nil, errors.New("( could not be the first charactor")
			}
			if !strings.HasSuffix(key, ")") {
				return nil, fmt.Errorf("field %s tag %s cannot match ) charactor", opt.Key, key)
			}
			opt.tagName = strings.ToUpper(key[:pStart])
			opt.params = strings.Split(key[pStart+1:len(key)-1], ",")
		}
		if j > 0 {
			opt.preTag = strings.ToUpper(tags[j-1])
		}
		if j < len(tags)-1 {
			opt.nextTag = tags[j+1]
		} else {
			opt.nextTag = ""
		}
		if h, ok := tagHandlers[opt.tagName]; ok {
			if err := h(opt); err != nil {
				return nil, err
			}
			continue
		}
		// Unrecognised token: it is the constraint name. The length check
		// keeps a lone "'" (both prefix and suffix) from slicing key[1:0].
		if len(key) >= 2 && strings.HasPrefix(key, "'") && strings.HasSuffix(key, "'") {
			opt.Name = key[1 : len(key)-1]
		} else {
			opt.Name = key
		}
	}
	return opt, nil
}
// splitTag splits a foreign_key tag value into tokens separated by ASCII
// spaces. Spaces inside single quotes do not split, so 'my fk' stays one token
// (quotes kept). An unmatched quote makes the rest of the tag a single token.
// Each token is trimmed of surrounding whitespace. Runs of consecutive spaces
// and whitespace-only tokens never produce empty tokens.
func splitTag(tag string) (tags []string) {
	tag = strings.TrimSpace(tag)
	inQuote := false
	start := 0
	for i, r := range tag {
		switch {
		case r == '\'':
			inQuote = !inQuote
		case r == ' ' && !inQuote:
			if tok := strings.TrimSpace(tag[start:i]); tok != "" {
				tags = append(tags, tok)
			}
			start = i + 1
		}
	}
	if tok := strings.TrimSpace(tag[start:]); tok != "" {
		tags = append(tags, tok)
	}
	return tags
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment