Skip to content

Instantly share code, notes, and snippets.

@leyafo
Created September 7, 2017 09:28
Show Gist options
  • Save leyafo/6c3fd4941a2aee558dcc1855f3e9e4b6 to your computer and use it in GitHub Desktop.
Save leyafo/6c3fd4941a2aee558dcc1855f3e9e4b6 to your computer and use it in GitHub Desktop.
Generates Go CRUD model code from a PostgreSQL database schema.
package main
import (
"fmt"
"os"
"path"
"text/template"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"data-git.mana.com/golang/chestnut/app/utils"
)
// Template data handed to table.go.tpl.
type (
	// columnInfo describes one database column as a field of the
	// generated model struct.
	columnInfo struct {
		DbTag string // column name, emitted as the `db:"..."` struct tag
		Type  string // Go type of the generated field
		Field string // Go field name (CamelCase form of the column name)
	}

	// tableInfo holds everything the template needs to render one table.
	tableInfo struct {
		TableName  string       // database table name
		Columns    []columnInfo // the table's columns, in the order they were read
		StructName string       // name of the generated Go struct
		Var        string       // receiver variable used by the generated methods
		Imports    []string     // extra import paths required by the column types
	}
)
// main generates one CRUD model file per table in the public schema of the
// configured database (schema_migrations excluded). Each table is rendered
// through ./table.go.tpl into ./models/<singular table name>_gen.go.
//
// This is a one-shot developer tool, so every failure aborts with a panic.
func main() {
	adapter, dbString := "", "" //config your database connection
	if adapter == "" || dbString == "" {
		panic("Please config your database connection before!!!!")
	}
	db, err := sqlx.Connect(adapter, dbString)
	if err != nil {
		panic(err.Error())
	}
	defer db.Close()

	tables, err := listTables(db)
	if err != nil {
		panic(fmt.Sprintf("listing tables: %v", err))
	}
	for _, table := range tables {
		ti, err := describeTable(db, table)
		if err != nil {
			panic(fmt.Sprintf("describing table %q: %v", table, err))
		}
		if err := writeModel(ti); err != nil {
			panic(fmt.Sprintf("writing model for table %q: %v", table, err))
		}
	}
}

// listTables returns the names of all tables in the public schema except
// schema_migrations. The whole list is read before returning, so no result
// set stays open while each table's columns are queried.
func listTables(db *sqlx.DB) ([]string, error) {
	var names []string
	if err := db.Select(&names, `select tablename from pg_tables where schemaname = 'public' and tablename != 'schema_migrations'`); err != nil {
		return nil, err
	}
	return names, nil
}

// describeTable collects the template data for one table.
//
// It sets the struct name (singular, CamelCase), the receiver variable
// (first character of the table name) and one columnInfo per column.
// Columns named "id" or ending in "_id" become uint64. All others go through
// mappingDbType, which panics on unsupported types. Columns are ordered by
// ordinal_position so the generated struct follows the table definition and
// is identical across runs.
func describeTable(db *sqlx.DB, table string) (tableInfo, error) {
	ti := tableInfo{
		TableName:  table,
		StructName: utils.ToCamelCase(utils.Singular(table)),
		Var:        table[0:1],
	}
	const query = `SELECT column_name,data_type FROM information_schema.columns WHERE table_name=$1 AND table_schema='public' ORDER BY ordinal_position;`
	rows, err := db.Queryx(query, table)
	if err != nil {
		return tableInfo{}, err
	}
	defer rows.Close()
	for rows.Next() {
		var col, colType string
		if err := rows.Scan(&col, &colType); err != nil {
			return tableInfo{}, err
		}
		ci := columnInfo{DbTag: col, Field: utils.ToCamelCase(col)}
		if l := len(col); col == "id" || (l > 3 && col[l-3:] == "_id") {
			ci.Type = "uint64"
		} else {
			ci.Type = mappingDbType(colType)
		}
		ti.Columns = append(ti.Columns, ci)
		ti.Imports = appendImport(ti.Imports, ci.Type)
	}
	if err := rows.Err(); err != nil {
		return tableInfo{}, err
	}
	return ti, nil
}

// writeModel renders ti into ./models/<singular table name>_gen.go, creating
// or truncating the file. The file is always closed. A Close error is
// returned when nothing else failed first.
func writeModel(ti tableInfo) (err error) {
	f, err := os.Create(path.Join("./models", utils.Singular(ti.TableName)+"_gen.go"))
	if err != nil {
		return err
	}
	defer func() {
		if cerr := f.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()
	genCode(f, ti)
	return nil
}
// appendImport returns imports extended with the package path required by
// the generated Go type typ. Types that need no extra import leave imports
// unchanged: builtins, and time.* types, which the template always imports.
// A path that is already present is not added a second time.
//
// NOTE(review): "uuid.UUID" maps to the bare path "uuid", which is not a
// resolvable import path; confirm which uuid package the models should use.
func appendImport(imports []string, typ string) []string {
	pkgPath := map[string]string{
		"pq.StringArray":  "github.com/lib/pq",
		"decimal.Decimal": "github.com/shopspring/decimal",
		"hstore.Hstore":   "github.com/lib/pq/hstore",
		"uuid.UUID":       "uuid",
	}[typ]
	if pkgPath == "" {
		return imports
	}
	for _, existing := range imports {
		if existing == pkgPath {
			return imports
		}
	}
	return append(imports, pkgPath)
}
// genCode renders the template file ./table.go.tpl, with ti as its data, into
// f. The template may call insertStr and updateStr. A parse or execution
// failure panics with the error text; partial output may already be in f.
// The returned string is always empty.
func genCode(f *os.File, ti tableInfo) string {
	helpers := template.FuncMap{
		"insertStr": insertStr,
		"updateStr": updateStr,
	}
	tpl, err := template.New("table.go.tpl").Funcs(helpers).ParseFiles("./table.go.tpl")
	if err != nil {
		panic(err.Error())
	}
	if err = tpl.Execute(f, ti); err != nil {
		panic(err.Error())
	}
	return ""
}
// insertStr renders the column list and VALUES clause of a sqlx named INSERT,
// e.g. `("name","age") VALUES (:name,:age)`.
//
// The "id" column is skipped. Column names are double-quoted so reserved
// words (order, user, ...) and mixed-case names are valid SQL. Names that
// themselves contain a double quote are not escaped.
// When no column remains it returns "DEFAULT VALUES", because PostgreSQL
// rejects an empty "() VALUES ()".
func insertStr(columns []columnInfo) string {
	names, params := "", ""
	for _, ci := range columns {
		if ci.DbTag == "id" {
			continue
		}
		if names != "" {
			names += ","
			params += ","
		}
		names += `"` + ci.DbTag + `"`
		params += ":" + ci.DbTag
	}
	if names == "" {
		return "DEFAULT VALUES"
	}
	return "(" + names + ") VALUES (" + params + ")"
}
// updateStr renders the SET clause of a sqlx named UPDATE, e.g.
// `SET "name"=:name,"age"=:age`. The "id" column is skipped.
//
// Column names are double-quoted so reserved words and mixed-case names are
// valid SQL; the named parameters keep the bare column name. If every column
// is skipped, only "SET " is returned.
func updateStr(columns []columnInfo) string {
	s := "SET "
	first := true
	for _, ci := range columns {
		if ci.DbTag == "id" {
			continue
		}
		if !first {
			s += ","
		}
		s += `"` + ci.DbTag + `"=:` + ci.DbTag
		first = false
	}
	return s
}
// mappingDbType maps a PostgreSQL information_schema.columns.data_type value
// to the Go type of the generated struct field. It panics with
// "Cannot convert type: <type>" for a type it does not know.
//
// Types with no faithful scalar Go equivalent are kept as their text form:
// interval, bit varying, json/jsonb, xml, and network/mac address types all
// map to string.
//
// NOTE(review): information_schema reports every extension and enum type as
// "USER-DEFINED", not only hstore. Mapping it to hstore.Hstore will not scan
// enums; telling them apart needs udt_name from information_schema.columns.
func mappingDbType(dbType string) string {
	switch dbType {
	case "boolean":
		return "bool"
	case "character", "character varying", "text", "money", "inet", "cidr", "macaddr",
		"json", "jsonb", "xml", "interval", "bit varying":
		return "string"
	case "ARRAY":
		return "pq.StringArray"
	case "smallint":
		return "int16"
	case "integer":
		return "int32"
	case "bigint":
		return "int64"
	case "smallserial":
		return "uint16"
	case "serial", "bigserial":
		return "uint64"
	case "real", "numeric", "double precision":
		return "decimal.Decimal"
	case "bytea":
		// bytea is a variable-length byte string, not a single byte.
		return "[]byte"
	case "date", "timestamp with time zone", "time with time zone", "time without time zone", "timestamp without time zone":
		return "time.Time"
	case `"char"`, "bit":
		return "uint8"
	case `"any"`:
		return "byte"
	case "hstore", "USER-DEFINED":
		return "hstore.Hstore"
	case "uuid":
		return "uuid.UUID"
	}
	panic(fmt.Sprintf("Cannot convert type: %s", dbType))
}
// Code generated by the CRUD model generator from table {{.TableName}}. DO NOT EDIT.

package models

import (
	"errors"
	"time"
{{- range .Imports}}
	"{{.}}"
{{- end}}
)

// {{.StructName}} is one row of the {{.TableName}} table.
type {{.StructName}} struct {
{{- range .Columns}}
	{{.Field}} {{.Type}} `db:"{{.DbTag}}"`
{{- end}}
	_exists  bool `db:"-"`
	_deleted bool `db:"-"`
}
// TableName returns "{{.TableName}}", the table the generated CRUD SQL targets.
func ({{.Var}} {{.StructName}}) TableName() string{
return "{{.TableName}}"
}
// Exists reports the _exists flag, which Insert sets after a successful insert.
func ({{.Var}} *{{.StructName}}) Exists() bool{
return {{.Var}}._exists
}
// Deleted reports the _deleted flag, which Delete sets after a successful delete.
func ({{.Var}} *{{.StructName}}) Deleted() bool{
return {{.Var}}._deleted
}
// Insert adds the value as a new row of {{.TableName}}.
//
// It sets CreatedAt and UpdatedAt to the current time and stores the id
// returned by the database in Id. Only after the transaction commits does it
// mark the value as existing. It fails if the value already exists.
func ({{.Var}} *{{.StructName}}) Insert() error {
	if {{.Var}}._exists {
		return errors.New("insert failed: already exists")
	}
	{{.Var}}.CreatedAt = time.Now()
	{{.Var}}.UpdatedAt = {{.Var}}.CreatedAt
	const sqlstr = `
	INSERT INTO {{.TableName}} {{insertStr .Columns}} RETURNING id
	`
	// Beginx returns an error instead of panicking like MustBegin.
	tx, err := DB.Beginx()
	if err != nil {
		return err
	}
	stmt, err := tx.PrepareNamed(sqlstr)
	if err != nil {
		tx.Rollback()
		return err
	}
	// The receiver is already a pointer; sqlx binds the named parameters
	// from its db tags.
	if err = stmt.QueryRowx({{.Var}}).Scan(&{{.Var}}.Id); err != nil {
		tx.Rollback()
		return err
	}
	if err = tx.Commit(); err != nil {
		return err
	}
	{{.Var}}._exists = true
	return nil
}
// Update writes every non-id column of the value back to its {{.TableName}}
// row (matched by id), after refreshing UpdatedAt. It fails if the value has
// not been inserted or is marked as deleted, and returns the commit error.
func ({{.Var}} *{{.StructName}}) Update() error {
	if !{{.Var}}._exists {
		return errors.New("Update failed: does not exists")
	}
	if {{.Var}}._deleted {
		return errors.New("Update failed: marked for deletion")
	}
	{{.Var}}.UpdatedAt = time.Now()
	// updateStr already emits the leading "SET". Wrapping it in parentheses,
	// as before, is a PostgreSQL syntax error.
	const sqlstr = `
	UPDATE {{.TableName}} {{updateStr .Columns}} WHERE id=:id
	`
	tx, err := DB.Beginx()
	if err != nil {
		return err
	}
	if _, err = tx.NamedExec(sqlstr, {{.Var}}); err != nil {
		tx.Rollback()
		return err
	}
	return tx.Commit()
}
// Delete removes the value's row (matched by id) from {{.TableName}} and marks
// the value as deleted once the transaction commits. It is a no-op for a
// value already marked as deleted and fails if the value was never inserted.
func ({{.Var}} *{{.StructName}}) Delete() error {
	if !{{.Var}}._exists {
		return errors.New("Delete failed: does not exists")
	}
	if {{.Var}}._deleted {
		return nil
	}
	const sqlstr = `
	DELETE FROM {{.TableName}} where id= $1
	`
	tx, err := DB.Beginx()
	if err != nil {
		return err
	}
	// Exec, not Queryx: DELETE returns no rows. An unclosed *Rows would keep
	// the transaction's connection busy.
	if _, err = tx.Exec(sqlstr, {{.Var}}.Id); err != nil {
		tx.Rollback()
		return err
	}
	if err = tx.Commit(); err != nil {
		return err
	}
	{{.Var}}._deleted = true
	return nil
}
// Upsert writes the value with INSERT ... ON CONFLICT (id) DO UPDATE.
//
// CreatedAt is set only when it is zero; UpdatedAt is always refreshed.
// It refuses values already marked as existing or as deleted.
//
// NOTE(review): insertStr leaves the id column out of the INSERT, so the
// conflict branch can only fire if the id the database assigns collides with
// an existing row. Confirm the intended conflict target. Upsert also does not
// set Id or _exists afterwards.
func ({{.Var}} *{{.StructName}}) Upsert() error {
	if {{.Var}}._exists {
		return errors.New("Upsert failed: already exists")
	}
	if {{.Var}}._deleted {
		return errors.New("Upsert failed: marked for deletion")
	}
	if {{.Var}}.CreatedAt.IsZero() {
		{{.Var}}.CreatedAt = time.Now()
	}
	{{.Var}}.UpdatedAt = time.Now()
	// PostgreSQL requires the conflict target in parentheses, and DO UPDATE
	// takes a bare SET list (updateStr already emits "SET ...").
	const sqlstr = `
	INSERT INTO {{.TableName}} {{insertStr .Columns}}
	ON CONFLICT (id) DO UPDATE {{updateStr .Columns}}
	`
	tx, err := DB.Beginx()
	if err != nil {
		return err
	}
	if _, err = tx.NamedExec(sqlstr, {{.Var}}); err != nil {
		tx.Rollback()
		return err
	}
	return tx.Commit()
}
// Save persists the value: it calls Update when the _exists flag is set and
// Insert otherwise, returning that method's error.
func ({{.Var}} *{{.StructName}}) Save() error{
if {{.Var}}._exists {
return {{.Var}}.Update()
}
return {{.Var}}.Insert()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment