Skip to content

Instantly share code, notes, and snippets.

@Loupax
Created June 17, 2020 11:26
Show Gist options
  • Save Loupax/b828f3b0ff46439d14c61eee014dd410 to your computer and use it in GitHub Desktop.
Save Loupax/b828f3b0ff46439d14c61eee014dd410 to your computer and use it in GitHub Desktop.
Use struct tags to get database column names
package db
import (
	"database/sql"
	"database/sql/driver"
	"fmt"
	"net/url"
	"reflect"
	"strings"

	"app/config"

	_ "github.com/lib/pq"
)
// DSNOption holds a single key/value pair, presumably describing one
// connection-string option such as sslmode=disable.
//
// NOTE(review): NewDB reads cfg.DSNOptions as plain strings (it passes them
// to strings.Join) and never uses this type — confirm it is still needed.
type DSNOption struct {
	Key, Value string
}
// NewDB returns a *sql.DB for the PostgreSQL database described by cfg,
// using the "postgres" driver registered by github.com/lib/pq.
//
// The connection URL is built with net/url rather than string formatting.
// This percent-encodes the user, the password and the database name, so a
// password containing characters such as '@', ':', '/', '?' or '#' cannot
// corrupt the URL. cfg.DSNOptions are joined with '&' and used verbatim as
// the query string.
// NOTE(review): the options are assumed to be already-encoded key=value
// pairs (e.g. "sslmode=disable") — confirm against config.DB.
//
// Per the database/sql documentation, sql.Open only validates its arguments
// and does not connect. Call Ping on the result to check connectivity.
func NewDB(cfg config.DB) (*sql.DB, error) {
	dsn := url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(cfg.User, cfg.Password),
		Host:     cfg.Host,
		Path:     "/" + cfg.DBName,
		RawQuery: strings.Join(cfg.DSNOptions, "&"),
	}
	db, err := sql.Open("postgres", dsn.String())
	if err != nil {
		// Deliberately omit the DSN from the message: it contains the password.
		return nil, fmt.Errorf("opening postgres database %q: %w", cfg.DBName, err)
	}
	return db, nil
}
// ColumnDataPair pairs a database table column name with the value stored
// in that column.
type ColumnDataPair struct {
	// Column is the table column name.
	Column string
	// Data is the column value, already normalised by
	// driver.DefaultParameterConverter when produced by GetFields.
	Data interface{}
}

// GetFields returns one ColumnDataPair per exported field of s, in
// declaration order, describing s as a database row.
//
// s must be a struct or a non-nil pointer to a struct. Column names come
// from the `db` struct tag; when the tag is absent or empty, the Go field
// name is used instead. Unexported fields are skipped, because reflection
// cannot read their values. Each value passes through
// driver.DefaultParameterConverter, so for example an int becomes an int64.
//
// An error is returned if s is nil, is not a struct (or pointer to one), or
// has a field whose value the converter rejects.
func GetFields(s interface{}) ([]ColumnDataPair, error) {
	v := reflect.ValueOf(s)
	if v.Kind() == reflect.Ptr {
		if v.IsNil() {
			return nil, fmt.Errorf("db: GetFields: nil %T", s)
		}
		v = v.Elem()
	}
	// A nil interface yields an invalid Value whose Kind is Invalid, so it
	// is rejected here too instead of panicking in NumField.
	if v.Kind() != reflect.Struct {
		return nil, fmt.Errorf("db: GetFields: expected a struct or pointer to struct, got %T", s)
	}
	t := v.Type()

	var row []ColumnDataPair
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		// A non-empty PkgPath marks an unexported field. Calling Interface()
		// on such a field panics, so skip it.
		if field.PkgPath != "" {
			continue
		}
		col := field.Tag.Get("db")
		if col == "" {
			col = field.Name
		}
		val, err := driver.DefaultParameterConverter.ConvertValue(v.Field(i).Interface())
		if err != nil {
			return nil, fmt.Errorf("db: GetFields: field %s (column %q): %w", field.Name, col, err)
		}
		row = append(row, ColumnDataPair{Column: col, Data: val})
	}
	return row, nil
}
package db
import (
"reflect"
"testing"
)
// TestGetFields checks the following:
//   - GetFields maps struct fields to ColumnDataPairs in declaration order.
//   - It uses the `db` tag when present and falls back to the Go field name
//     otherwise.
//   - It reports an error when a field's value cannot be converted by the
//     database/sql/driver default converter.
func TestGetFields(t *testing.T) {
	type args struct {
		s interface{}
	}
	tests := []struct {
		name    string
		args    args
		want    []ColumnDataPair
		wantErr bool
	}{
		{
			"Returns the columns in the correct order",
			args{
				struct {
					String    string `db:"column"`
					SpecialID int    `db:"id"`
				}{
					String:    "whatever",
					SpecialID: 1234,
				},
			},
			[]ColumnDataPair{{"column", "whatever"}, {"id", int64(1234)}},
			false,
		},
		{
			"Falls back to the field name when there is no db tag",
			args{
				struct {
					Untagged bool
				}{
					Untagged: true,
				},
			},
			[]ColumnDataPair{{"Untagged", true}},
			false,
		},
		{
			"Returns an error for a value the driver cannot convert",
			args{
				struct {
					Nested struct{ X int } `db:"nested"`
				}{},
			},
			nil,
			true,
		},
	}
	for _, tt := range tests {
		tt := tt // capture range variable for the parallel subtest (pre-Go 1.22)
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got, err := GetFields(tt.args.s)
			if (err != nil) != tt.wantErr {
				t.Errorf("GetFields() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.wantErr {
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("GetFields() got = %v, want %v", got, tt.want)
			}
		})
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment