Last active
March 7, 2018 01:43
-
-
Save giautm/5c7289dd4d1b660859300c8fcf71ea9d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gormrepo | |
import ( | |
"context" | |
"reflect" | |
"github.com/jinzhu/gorm" | |
eh "github.com/looplab/eventhorizon" | |
"github.com/pkg/errors" | |
) | |
type GormRepository struct { | |
db *gorm.DB | |
factoryFn func() eh.Entity | |
} | |
func NewRepo(db *gorm.DB) *GormRepository { | |
return &GormRepository{ | |
db: db, | |
} | |
} | |
// SetEntityFactory sets a factory function that creates concrete entity types. | |
func (r *GormRepository) SetEntityFactory(f func() eh.Entity) { | |
r.factoryFn = f | |
} | |
// Save implements the Save method of the WriteRepo interface. | |
func (r *GormRepository) Save( | |
ctx context.Context, | |
entity eh.Entity, | |
) error { | |
if err := r.db.Save(entity).Error; err != nil { | |
return eh.RepoError{ | |
Err: eh.ErrCouldNotSaveEntity, | |
BaseErr: err, | |
Namespace: eh.NamespaceFromContext(ctx), | |
} | |
} | |
return nil | |
} | |
// Remove implements the Remove method of the WriteRepo interface. | |
func (r *GormRepository) Remove( | |
ctx context.Context, | |
id eh.UUID, | |
) (err error) { | |
if id == eh.UUID("") { | |
return eh.RepoError{ | |
Err: eh.ErrMissingEntityID, | |
Namespace: eh.NamespaceFromContext(ctx), | |
} | |
} | |
entity := r.factoryFn() | |
if err := r.db.Scopes(findByID(entity, id)).Delete(entity).Error; err != nil { | |
return eh.RepoError{ | |
Err: eh.ErrEntityNotFound, | |
BaseErr: err, | |
Namespace: eh.NamespaceFromContext(ctx), | |
} | |
} | |
return nil | |
} | |
// Parent implements the Parent method of the ReadRepo interface. | |
func (r *GormRepository) Parent() eh.ReadRepo { | |
return nil | |
} | |
// Find implements the Find method of the ReadRepo interface. | |
func (r *GormRepository) Find( | |
ctx context.Context, | |
id eh.UUID, | |
) (eh.Entity, error) { | |
if id == eh.UUID("") { | |
return nil, eh.RepoError{ | |
Err: eh.ErrMissingEntityID, | |
Namespace: eh.NamespaceFromContext(ctx), | |
} | |
} | |
entity := r.factoryFn() | |
if err := r.db.Scopes(findByID(entity, id)).Find(entity).Error; err != nil { | |
return nil, eh.RepoError{ | |
BaseErr: err, | |
Err: eh.ErrEntityNotFound, | |
Namespace: eh.NamespaceFromContext(ctx), | |
} | |
} | |
return entity, nil | |
} | |
// FindAll implements the FindAll method of the ReadRepo interface. | |
func (r *GormRepository) FindAll( | |
ctx context.Context, | |
) ([]eh.Entity, error) { | |
slicePtr := makeSlicePtr(r.factoryFn()) | |
if err := r.db.Find(slicePtr.Interface()).Error; err != nil { | |
return nil, eh.RepoError{ | |
BaseErr: err, | |
Err: eh.ErrEntityNotFound, | |
Namespace: eh.NamespaceFromContext(ctx), | |
} | |
} | |
slice := slicePtr.Elem() | |
// FIXME: Convert []interface{} to []Entity | |
entities := make([]eh.Entity, slice.Len()) | |
for i := 0; i < slice.Len(); i++ { | |
if elem, ok := slice.Index(i).Interface().(eh.Entity); ok { | |
entities[i] = elem | |
} | |
} | |
return entities, nil | |
} | |
func findByID(value interface{}, id eh.UUID) func(db *gorm.DB) *gorm.DB { | |
return func(db *gorm.DB) *gorm.DB { | |
s := db.NewScope(value) | |
k := s.PrimaryKey() | |
return db.Where(s.Quote(k)+"= ?", id) | |
} | |
} | |
func makeSlicePtr(entity eh.Entity) reflect.Value { | |
// Create a slice to begin with | |
entityType := reflect.TypeOf(entity) | |
slice := reflect.MakeSlice(reflect.SliceOf(entityType), 0, 0) | |
// Create a pointer to a slice value and set it to the slice | |
slicePtr := reflect.New(slice.Type()) | |
slicePtr.Elem().Set(slice) | |
return slicePtr | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gormrepo_test | |
import ( | |
"context" | |
"fmt" | |
"log" | |
"testing" | |
"github.com/giautm/lms-cargo/infrastructure/repo/gormrepo" | |
eh "github.com/looplab/eventhorizon" | |
"github.com/jinzhu/gorm" | |
_ "github.com/jinzhu/gorm/dialects/sqlite" | |
) | |
type TestModel struct { | |
ID eh.UUID | |
Name string | |
} | |
func (t *TestModel) EntityID() eh.UUID { | |
return t.ID | |
} | |
func Factory() eh.Entity { | |
return &TestModel{} | |
} | |
func getDB(t *testing.T) *gorm.DB { | |
db, err := gorm.Open("sqlite3", "./gorm.db") | |
if err != nil { | |
t.Error(err) | |
return nil | |
} | |
db.LogMode(true) | |
db.AutoMigrate(&TestModel{}) | |
return db | |
} | |
func TestFindAll(t *testing.T) { | |
if db := getDB(t); db != nil { | |
defer db.Close() | |
entity := &TestModel{} | |
entity.ID = eh.NewUUID() | |
entity.Name = "Giau Tran" | |
repo := gormrepo.NewRepo(db) | |
repo.SetEntityFactory(Factory) | |
ctx := context.Background() | |
if err := repo.Save(ctx, entity); err != nil { | |
t.Error(err) | |
} | |
entity.ID = eh.NewUUID() | |
if err := repo.Save(ctx, entity); err != nil { | |
t.Error(err) | |
} | |
if entities, err := repo.FindAll(ctx); err != nil { | |
t.Error(err) | |
} else if len(entities) != 2 { | |
t.Error("Number of entities should equal 2.", len(entities)) | |
} else { | |
entity.ID = eh.NewUUID() | |
if err := repo.Save(ctx, entity); err != nil { | |
t.Error(err) | |
} | |
for _, a := range entities { | |
err := repo.Remove(ctx, a.EntityID()) | |
if err != nil { | |
panic(err) | |
} | |
} | |
entities, _ := repo.FindAll(ctx) | |
log.Println(len(entities)) | |
repo.Remove(ctx, entity.EntityID()) | |
} | |
} | |
} | |
func TestFindOne(t *testing.T) { | |
if db := getDB(t); db != nil { | |
defer db.Close() | |
entity := &TestModel{} | |
entity.ID = eh.NewUUID() | |
entity.Name = "Giau Tran" | |
repo := gormrepo.NewRepo(db) | |
repo.SetEntityFactory(Factory) | |
ctx := context.Background() | |
if err := repo.Save(ctx, entity); err != nil { | |
t.Error("Save error", err) | |
} else { | |
fmt.Print(entity) | |
} | |
id := entity.EntityID() | |
if entity, err := repo.Find(ctx, id); err != nil { | |
t.Error("Find error", err) | |
} else if entity.EntityID() != id { | |
t.Error("Entity should has correct ID", id) | |
} else { | |
repo.Remove(ctx, id) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment