Skip to content

Instantly share code, notes, and snippets.

@TMaYaD
Created August 5, 2015 06:03
Show Gist options
  • Save TMaYaD/2de2578c4a9e4a6ffd70 to your computer and use it in GitHub Desktop.
Save TMaYaD/2de2578c4a9e4a6ffd70 to your computer and use it in GitHub Desktop.
package gorm_test
import (
"errors"
"github.com/jinzhu/gorm"
"reflect"
"testing"
"time"
)
// Parent is the owning side of the Parent/Child association exercised by the
// callback tests in this file.
type Parent struct {
// Id is the record identifier (presumably the primary key gorm assigns on
// insert — confirm against the gorm version in use).
Id int64
// Children holds pointers to the Child records attached to this parent; the
// tests save a Parent with a populated slice and then inspect the children.
Children []*Child
}
// Child is the associated record whose callback hooks are instrumented by the
// tests in this file. Each hook bumps one of the *CallTimes counters so a test
// can assert exactly which hooks ran.
type Child struct {
	Id       int64
	ParentId int64 // identifier of the owning Parent
	Parent   Parent
	Code     string // specific values ("Invalid", "dont_save", ...) make hooks fail
	Price    int64

	// Timestamps; presumably maintained by gorm on save — confirm.
	CreatedAt time.Time
	UpdatedAt time.Time

	// Hook invocation counters, declared in the same order as before:
	// find, then before/after pairs for create, update, save and delete.
	AfterFindCallTimes                          int64
	BeforeCreateCallTimes, AfterCreateCallTimes int64
	BeforeUpdateCallTimes, AfterUpdateCallTimes int64
	BeforeSaveCallTimes, AfterSaveCallTimes     int64
	BeforeDeleteCallTimes, AfterDeleteCallTimes int64
}
// BeforeCreate records one invocation of the before-create hook and rejects
// the record when its Code is "Invalid".
//
// The counter is bumped on every call, including the rejecting one.
func (s *Child) BeforeCreate() (err error) {
	s.BeforeCreateCallTimes++
	if s.Code == "Invalid" {
		return errors.New("invalid child")
	}
	return nil
}
// BeforeUpdate records one invocation of the before-update hook and refuses
// the update when Code is "dont_update".
//
// The counter is incremented whether or not an error is returned.
func (s *Child) BeforeUpdate() (err error) {
	s.BeforeUpdateCallTimes++
	if s.Code == "dont_update" {
		return errors.New("can't update")
	}
	return nil
}
// BeforeSave records one invocation of the before-save hook and vetoes the
// save when Code is "dont_save".
//
// The counter is incremented on every call, vetoed or not.
func (s *Child) BeforeSave() (err error) {
	s.BeforeSaveCallTimes++
	if s.Code == "dont_save" {
		return errors.New("can't save")
	}
	return nil
}
// AfterFind records one invocation of the after-find hook on the receiver.
// It never fails and only touches the in-memory counter.
func (s *Child) AfterFind() {
	s.AfterFindCallTimes++
}
// AfterCreate records one invocation of the after-create hook.
//
// Unlike the other hooks, it does not assign to the receiver directly: the
// incremented counter is written through tx with UpdateColumn, scoped to the
// receiver via Model(s). The result of that call is not inspected.
// NOTE(review): the tests expect the receiver's counter to reflect the new
// value afterwards, which relies on gorm syncing the model — confirm.
func (s *Child) AfterCreate(tx *gorm.DB) {
	bumped := Child{AfterCreateCallTimes: s.AfterCreateCallTimes + 1}
	tx.Model(s).UpdateColumn(bumped)
}
// AfterUpdate records one invocation of the after-update hook on the
// receiver. It never fails and only touches the in-memory counter.
func (s *Child) AfterUpdate() {
	s.AfterUpdateCallTimes++
}
// AfterSave records one invocation of the after-save hook and reports a
// failure when Code is "after_save_error".
//
// The counter is incremented on every call, failing or not.
func (s *Child) AfterSave() (err error) {
	s.AfterSaveCallTimes++
	if s.Code == "after_save_error" {
		return errors.New("can't save")
	}
	return nil
}
// BeforeDelete records one invocation of the before-delete hook and blocks
// the delete when Code is "dont_delete".
//
// The counter is incremented on every call, blocked or not.
func (s *Child) BeforeDelete() (err error) {
	s.BeforeDeleteCallTimes++
	if s.Code == "dont_delete" {
		return errors.New("can't delete")
	}
	return nil
}
// AfterDelete records one invocation of the after-delete hook and reports a
// failure when Code is "after_delete_error".
//
// The counter is incremented on every call, failing or not.
func (s *Child) AfterDelete() (err error) {
	s.AfterDeleteCallTimes++
	if s.Code == "after_delete_error" {
		return errors.New("can't delete")
	}
	return nil
}
// GetCallTimes returns a snapshot of the nine hook counters in a fixed order
// that the tests compare against literal slices:
//
//	BeforeCreate, BeforeSave, BeforeUpdate,
//	AfterCreate, AfterSave, AfterUpdate,
//	BeforeDelete, AfterDelete, AfterFind
func (s *Child) GetCallTimes() []int64 {
	counts := make([]int64, 0, 9)
	counts = append(counts, s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes)
	counts = append(counts, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes)
	counts = append(counts, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes)
	return counts
}
// TestRunAssociationCallbacks saves, reloads, updates and deletes a Parent
// that owns a single Child, and after each step asserts which of the Child's
// callback hooks have run (via Child.GetCallTimes).
//
// Counters written only in memory by a hook (e.g. AfterSave) are expected to
// reset when the child is reloaded from the database, which is why the
// expected slices shrink after each First call.
func TestRunAssociationCallbacks(t *testing.T) {
	if err := DB.AutoMigrate(&Parent{}, &Child{}).Error; err != nil {
		t.Fatalf("setup: migrating Parent/Child: %v", err)
	}
	DB.LogMode(true)
	// Restore quiet logging even when the test aborts early.
	defer DB.LogMode(false)

	c := Child{Code: "unique_code", Price: 100}
	p := Parent{Children: []*Child{&c}}

	if err := DB.Save(&p).Error; err != nil {
		t.Fatalf("saving parent with child: %v", err)
	}
	if !reflect.DeepEqual(c.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) {
		t.Errorf("Callbacks should be invoked successfully, %v", c.GetCallTimes())
	}

	// Reload the child, not the parent: Parent has no Code field to filter on,
	// and the assertion below inspects the child's persisted counters.
	DB.Where("Code = ?", "unique_code").First(&c)
	if !reflect.DeepEqual(c.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) {
		t.Errorf("After callbacks values are not saved, %v", c.GetCallTimes())
	}

	c.Price = 200
	if err := DB.Save(&p).Error; err != nil {
		t.Fatalf("re-saving parent after child update: %v", err)
	}
	if !reflect.DeepEqual(c.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) {
		t.Errorf("After update callbacks should be invoked successfully, %v", c.GetCallTimes())
	}

	var children []Child
	DB.Find(&children, "code = ?", "unique_code")
	// Guard the index below; an empty result would otherwise panic.
	if len(children) == 0 {
		t.Fatalf("expected at least one child with code %q, found none", "unique_code")
	}
	if children[0].AfterFindCallTimes != 2 {
		t.Errorf("AfterFind callbacks should work with slice")
	}

	DB.Where("Code = ?", "unique_code").First(&c)
	if !reflect.DeepEqual(c.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) {
		t.Errorf("After update callbacks values are not saved, %v", c.GetCallTimes())
	}

	DB.Delete(&p)
	if !reflect.DeepEqual(c.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) {
		t.Errorf("After delete callbacks should be invoked successfully, %v", c.GetCallTimes())
	}

	if DB.Where("Code = ?", "unique_code").First(&c).Error == nil {
		t.Errorf("Can't find a deleted record")
	}
}
// TestAssociationCallbacksWithErrors checks that an error returned from a
// Child callback hook aborts the corresponding operation: the Save/Delete call
// reports an error and the database is left as it was before the call.
//
// Each scenario is selected by a magic Code value that the Child hooks in this
// file recognise ("Invalid", "dont_save", "dont_update", "dont_delete",
// "after_save_error", "after_delete_error").
func TestAssociationCallbacksWithErrors(t *testing.T) {
	// BeforeCreate rejects the child, so saving the owning parent must fail.
	c := Child{Code: "Invalid", Price: 100}
	p := Parent{Children: []*Child{&c}}
	if DB.Save(&p).Error == nil {
		t.Errorf("An error from before create callbacks happened when create with invalid value")
	}

	if DB.Where("code = ?", "Invalid").First(&Child{}).Error == nil {
		t.Errorf("Should not save record that have errors")
	}

	// "dont_save" is rejected by BeforeSave (the message names that hook).
	if DB.Save(&Child{Code: "dont_save", Price: 100}).Error == nil {
		t.Errorf("An error from before save callbacks happened when create with invalid value")
	}

	// BeforeUpdate rejects the change; the stored row keeps its old code.
	p2 := Child{Code: "update_callback", Price: 100}
	if err := DB.Save(&p2).Error; err != nil {
		t.Fatalf("setup: saving update_callback child: %v", err)
	}

	p2.Code = "dont_update"
	if DB.Save(&p2).Error == nil {
		t.Errorf("An error from before update callbacks happened when update with invalid value")
	}

	if DB.Where("code = ?", "update_callback").First(&Child{}).Error != nil {
		t.Errorf("Record Should not be updated due to errors happened in before update callback")
	}

	if DB.Where("code = ?", "dont_update").First(&Child{}).Error == nil {
		t.Errorf("Record Should not be updated due to errors happened in before update callback")
	}

	p2.Code = "dont_save"
	if DB.Save(&p2).Error == nil {
		t.Errorf("An error from before save callbacks happened when update with invalid value")
	}

	// BeforeDelete rejects the delete; the row must still be there.
	p3 := Child{Code: "dont_delete", Price: 100}
	if err := DB.Save(&p3).Error; err != nil {
		t.Fatalf("setup: saving dont_delete child: %v", err)
	}
	if DB.Delete(&p3).Error == nil {
		t.Errorf("An error from before delete callbacks happened when delete")
	}

	if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil {
		t.Errorf("An error from before delete callbacks happened")
	}

	// AfterSave fails here on purpose, so the Save error is deliberately not
	// checked; what matters is that the insert was rolled back.
	p4 := Child{Code: "after_save_error", Price: 100}
	DB.Save(&p4)
	if err := DB.First(&Child{}, "code = ?", "after_save_error").Error; err == nil {
		t.Errorf("Record should be reverted if get an error in after save callback")
	}

	// AfterDelete fails, so the delete must be rolled back as well.
	p5 := Child{Code: "after_delete_error", Price: 100}
	if err := DB.Save(&p5).Error; err != nil {
		t.Fatalf("setup: saving after_delete_error child: %v", err)
	}
	if err := DB.First(&Child{}, "code = ?", "after_delete_error").Error; err != nil {
		t.Errorf("Record should be found")
	}

	DB.Delete(&p5)
	if err := DB.First(&Child{}, "code = ?", "after_delete_error").Error; err != nil {
		t.Errorf("Record shouldn't be deleted because of an error happened in after delete callback")
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment