Skip to content

Instantly share code, notes, and snippets.

@giautm
Last active March 7, 2018 01:43
Show Gist options
  • Save giautm/5c7289dd4d1b660859300c8fcf71ea9d to your computer and use it in GitHub Desktop.
Save giautm/5c7289dd4d1b660859300c8fcf71ea9d to your computer and use it in GitHub Desktop.
package gormrepo
import (
"context"
"reflect"
"github.com/jinzhu/gorm"
eh "github.com/looplab/eventhorizon"
"github.com/pkg/errors"
)
// GormRepository is a repository that stores eventhorizon entities through a
// gorm database handle. Its methods implement the eventhorizon ReadRepo and
// WriteRepo interfaces.
type GormRepository struct {
// db is the gorm handle every query of the repository runs through.
db *gorm.DB
// factoryFn returns a new, empty entity; its concrete type decides which
// model Find, FindAll and Remove query. NewRepo leaves it nil, so
// SetEntityFactory must be called before those methods are used.
factoryFn func() eh.Entity
}
// NewRepo returns a GormRepository that runs all of its queries through db.
//
// The returned repository has no entity factory yet; call SetEntityFactory
// before using Find, FindAll or Remove, which create entities through it.
func NewRepo(db *gorm.DB) *GormRepository {
	repo := new(GormRepository)
	repo.db = db
	return repo
}
// SetEntityFactory registers factory as the constructor for the concrete
// entity type this repository works with. Find, FindAll and Remove call it to
// obtain a value to query into, so it should be set before they are used.
func (r *GormRepository) SetEntityFactory(factory func() eh.Entity) {
	r.factoryFn = factory
}
// Save implements the Save method of the WriteRepo interface.
//
// It persists entity through gorm's Save. An entity with an empty ID is
// rejected up front: Find and Remove refuse empty IDs, so a row stored under
// one could never be read back or deleted through this repository.
//
// Errors are returned as an eh.RepoError wrapping eh.ErrCouldNotSaveEntity;
// BaseErr is eh.ErrMissingEntityID for an empty ID, or the gorm error.
func (r *GormRepository) Save(
	ctx context.Context,
	entity eh.Entity,
) error {
	if entity.EntityID() == eh.UUID("") {
		return eh.RepoError{
			Err:       eh.ErrCouldNotSaveEntity,
			BaseErr:   eh.ErrMissingEntityID,
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}
	if err := r.db.Save(entity).Error; err != nil {
		return eh.RepoError{
			Err:       eh.ErrCouldNotSaveEntity,
			BaseErr:   err,
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}
	return nil
}
// Remove implements the Remove method of the WriteRepo interface.
//
// It deletes the row whose primary key equals id. Errors are returned as an
// eh.RepoError:
//   - eh.ErrMissingEntityID when id is empty;
//   - eh.ErrEntityNotFound when the delete fails (gorm error in BaseErr) or
//     when it matched no row, so removing an unknown ID is reported rather
//     than silently succeeding.
//
// If no entity factory has been set, an error is returned instead of
// panicking on the nil factory.
func (r *GormRepository) Remove(
	ctx context.Context,
	id eh.UUID,
) error {
	if id == eh.UUID("") {
		return eh.RepoError{
			Err:       eh.ErrMissingEntityID,
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}
	if r.factoryFn == nil {
		return eh.RepoError{
			Err:       errors.New("gormrepo: entity factory not set"),
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}

	// The factory value only tells gorm which model/table to delete from.
	entity := r.factoryFn()
	result := r.db.Scopes(findByID(entity, id)).Delete(entity)
	if result.Error != nil {
		return eh.RepoError{
			Err:       eh.ErrEntityNotFound,
			BaseErr:   result.Error,
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}
	// A DELETE that matches nothing is not an error for gorm; treat it as
	// "not found" so callers learn the entity did not exist.
	if result.RowsAffected == 0 {
		return eh.RepoError{
			Err:       eh.ErrEntityNotFound,
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}
	return nil
}
// Parent implements the Parent method of the ReadRepo interface.
// A GormRepository does not wrap another repository, so it always reports
// a nil parent.
func (*GormRepository) Parent() eh.ReadRepo {
	return nil
}
// Find implements the Find method of the ReadRepo interface.
//
// It loads the row whose primary key equals id into a fresh value obtained
// from the entity factory and returns that value. Errors are returned as an
// eh.RepoError:
//   - eh.ErrMissingEntityID when id is empty;
//   - eh.ErrEntityNotFound when the query fails, with the gorm error as
//     BaseErr.
//
// If no entity factory has been set, an error is returned instead of
// panicking on the nil factory.
func (r *GormRepository) Find(
	ctx context.Context,
	id eh.UUID,
) (eh.Entity, error) {
	if id == eh.UUID("") {
		return nil, eh.RepoError{
			Err:       eh.ErrMissingEntityID,
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}
	if r.factoryFn == nil {
		return nil, eh.RepoError{
			Err:       errors.New("gormrepo: entity factory not set"),
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}

	entity := r.factoryFn()
	if err := r.db.Scopes(findByID(entity, id)).Find(entity).Error; err != nil {
		return nil, eh.RepoError{
			BaseErr:   err,
			Err:       eh.ErrEntityNotFound,
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}
	return entity, nil
}
// FindAll implements the FindAll method of the ReadRepo interface.
//
// It loads every row of the factory's model into a slice of that concrete
// type and returns the elements as eh.Entity values. A query failure is
// returned as an eh.RepoError wrapping eh.ErrEntityNotFound with the gorm
// error as BaseErr. If no entity factory has been set, an error is returned
// instead of panicking on the nil factory.
func (r *GormRepository) FindAll(
	ctx context.Context,
) ([]eh.Entity, error) {
	if r.factoryFn == nil {
		return nil, eh.RepoError{
			Err:       errors.New("gormrepo: entity factory not set"),
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}

	slicePtr := makeSlicePtr(r.factoryFn())
	if err := r.db.Find(slicePtr.Interface()).Error; err != nil {
		return nil, eh.RepoError{
			BaseErr:   err,
			Err:       eh.ErrEntityNotFound,
			Namespace: eh.NamespaceFromContext(ctx),
		}
	}

	slice := slicePtr.Elem()
	entities := make([]eh.Entity, slice.Len())
	for i := range entities {
		// The slice's element type is the factory's dynamic type, which
		// implements eh.Entity (the factory returns it as one), so this
		// assertion cannot fail.
		entities[i] = slice.Index(i).Interface().(eh.Entity)
	}
	return entities, nil
}
// findByID returns a gorm scope that restricts a query to the row whose
// primary-key column, as gorm reports it for value's model, equals id.
func findByID(value interface{}, id eh.UUID) func(db *gorm.DB) *gorm.DB {
	return func(db *gorm.DB) *gorm.DB {
		scope := db.NewScope(value)
		column := scope.Quote(scope.PrimaryKey())
		return db.Where(column+"= ?", id)
	}
}
// makeSlicePtr returns a reflect.Value holding a pointer to an empty, non-nil
// slice whose element type is entity's dynamic type (a *[]*T for an entity of
// type *T). The pointer can be handed to gorm's Find to receive all rows.
func makeSlicePtr(entity eh.Entity) reflect.Value {
	sliceType := reflect.SliceOf(reflect.TypeOf(entity))
	ptr := reflect.New(sliceType)
	ptr.Elem().Set(reflect.MakeSlice(sliceType, 0, 0))
	return ptr
}
package gormrepo_test
import (
"context"
"fmt"
"log"
"testing"
"github.com/giautm/lms-cargo/infrastructure/repo/gormrepo"
eh "github.com/looplab/eventhorizon"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
// TestModel is the entity type the repository tests save, find and remove.
type TestModel struct {
// ID identifies the entity. It carries no gorm tag, so the tests rely on
// gorm's convention of treating a field named ID as the primary key, which
// the repository's ID lookups use — TODO confirm for the gorm version in use.
ID eh.UUID
// Name is an arbitrary payload column filled in by the tests.
Name string
}
// EntityID returns the model's ID; it is what makes *TestModel an eh.Entity.
func (m *TestModel) EntityID() eh.UUID {
	id := m.ID
	return id
}
// Factory is the entity factory handed to the repository under test; each
// call yields a new zero-valued *TestModel.
func Factory() eh.Entity {
	model := new(TestModel)
	return model
}
// getDB opens a fresh in-memory SQLite database for one test, enables SQL
// logging and migrates the TestModel schema into it.
//
// An in-memory database keeps tests hermetic. Nothing is written to the
// working directory, and rows left behind by a failed test or an earlier run
// cannot change another test's results.
//
// Setup failures abort the calling test via t.Fatal, so the result is never
// nil. The caller must Close the returned handle.
func getDB(t *testing.T) *gorm.DB {
	t.Helper()
	db, err := gorm.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatal(err)
	}
	// Every connection to ":memory:" gets its own private database, so the
	// pool is pinned to one connection to make all queries share the schema.
	db.DB().SetMaxOpenConns(1)
	db.LogMode(true)
	if err := db.AutoMigrate(&TestModel{}).Error; err != nil {
		db.Close()
		t.Fatal(err)
	}
	return db
}
// TestFindAll saves two entities and checks that FindAll returns exactly two.
// It then saves a third, removes the two that were listed, and checks that
// only the third remains before removing it too.
func TestFindAll(t *testing.T) {
	db := getDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	repo := gormrepo.NewRepo(db)
	repo.SetEntityFactory(Factory)
	ctx := context.Background()

	// The same struct is reused; it gets a fresh ID before every Save so that
	// each Save is expected to store a separate row.
	entity := &TestModel{}
	entity.Name = "Giau Tran"
	for i := 0; i < 2; i++ {
		entity.ID = eh.NewUUID()
		if err := repo.Save(ctx, entity); err != nil {
			t.Fatal(err)
		}
	}

	entities, err := repo.FindAll(ctx)
	if err != nil {
		t.Fatal(err)
	}
	if len(entities) != 2 {
		t.Fatal("Number of entities should equal 2.", len(entities))
	}

	// Saved after the listing, so it is not in entities and must survive
	// the removal loop below.
	entity.ID = eh.NewUUID()
	if err := repo.Save(ctx, entity); err != nil {
		t.Fatal(err)
	}
	for _, e := range entities {
		if err := repo.Remove(ctx, e.EntityID()); err != nil {
			t.Fatal(err)
		}
	}

	remaining, err := repo.FindAll(ctx)
	if err != nil {
		t.Fatal(err)
	}
	log.Println(len(remaining))
	if len(remaining) != 1 || remaining[0].EntityID() != entity.ID {
		t.Errorf("expected only entity %s to remain, got %d entities", entity.ID, len(remaining))
	}

	if err := repo.Remove(ctx, entity.EntityID()); err != nil {
		t.Error(err)
	}
}
// TestFindOne saves an entity, checks that Find returns it with the same ID
// and Name, and removes it again when the test ends.
func TestFindOne(t *testing.T) {
	db := getDB(t)
	if db == nil {
		return
	}
	defer db.Close()

	entity := &TestModel{}
	entity.ID = eh.NewUUID()
	entity.Name = "Giau Tran"
	repo := gormrepo.NewRepo(db)
	repo.SetEntityFactory(Factory)
	ctx := context.Background()

	if err := repo.Save(ctx, entity); err != nil {
		t.Fatal("Save error", err)
	}
	fmt.Print(entity)

	id := entity.EntityID()
	// Deferred after db.Close, so it runs first (LIFO) while the DB is open,
	// and it cleans up even when an assertion below fails.
	defer func() {
		if err := repo.Remove(ctx, id); err != nil {
			t.Error("Remove error", err)
		}
	}()

	found, err := repo.Find(ctx, id)
	if err != nil {
		t.Fatal("Find error", err)
	}
	if found.EntityID() != id {
		t.Error("Entity should has correct ID", id)
	}
	if m, ok := found.(*TestModel); !ok || m.Name != entity.Name {
		t.Errorf("Find returned %#v, want Name %q", found, entity.Name)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment