Skip to content

Instantly share code, notes, and snippets.

@linxinemily
Created May 17, 2023 16:29
Show Gist options
  • Save linxinemily/276905e0145218538cb0be3a79a36153 to your computer and use it in GitHub Desktop.
Save linxinemily/276905e0145218538cb0be3a79a36153 to your computer and use it in GitHub Desktop.
package repository_test
import (
"database/sql"
"fmt"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
"golang.org/x/net/context"
"gorm.io/gorm/logger"
"log"
"strconv"
"time"
"github.com/golang-migrate/migrate/v4"
gomigratemysql "github.com/golang-migrate/migrate/v4/database/mysql"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/stretchr/testify/suite"
testcontainermysql "github.com/testcontainers/testcontainers-go/modules/mysql"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"testing"
)
// MysqlRepoTestSuite is a testify suite that runs repository tests against a
// disposable MariaDB container.
//
// dbUser, dbPassword and dbName are supplied by TestMySqlRepoTestSuite;
// dbHost and dbPort are filled in by SetupSuite once the container is up.
type MysqlRepoTestSuite struct {
	suite.Suite

	// Connection settings used by getDsn.
	dbUser, dbPassword, dbName string
	dbHost, dbPort             string

	// mariadbC is the container started in SetupSuite and terminated in
	// TearDownSuite.
	mariadbC *testcontainermysql.MySQLContainer

	// ctx is the context used for every container call.
	// NOTE(review): Go convention prefers passing a context over storing it;
	// it lives here because the suite lifecycle hooks take no arguments.
	ctx context.Context
}
// SetupSuite runs once before the suite's tests. It does three things:
//   - starts a MariaDB 10.5 container configured with the suite's
//     dbPassword and dbName;
//   - records the container's host and mapped 3306/tcp port for getDsn;
//   - resets the schema with freshDatabase.
//
// Every failure aborts the run with log.Fatal. log.Fatal exits through
// os.Exit, which skips deferred calls and TearDownSuite. So once the
// container exists, it is terminated explicitly before aborting rather than
// left running.
func (suite *MysqlRepoTestSuite) SetupSuite() {
	suite.ctx = context.Background()
	mariadbC, err := testcontainermysql.RunContainer(
		suite.ctx,
		testcontainers.WithImage("mariadb:10.5"),
		testcontainermysql.WithPassword(suite.dbPassword),
		testcontainermysql.WithDatabase(suite.dbName),
		testcontainers.WithWaitStrategy(
			// Wait for the second "ready" line. NOTE(review): presumably the
			// image's entrypoint logs it once for a temporary init server and
			// again for the real one — confirm when changing the image tag.
			wait.ForLog("mysqld: ready for connections.").
				WithOccurrence(2).
				WithStartupTimeout(2*time.Minute),
		),
	)
	if err != nil {
		log.Fatal("Failed to start MariaDB container: ", err)
	}
	// Store the container immediately so it is reachable from TearDownSuite.
	suite.mariadbC = mariadbC

	// terminate is best-effort cleanup used only on the failure paths below.
	terminate := func() {
		if termErr := mariadbC.Terminate(suite.ctx); termErr != nil {
			log.Print("Failed to terminate MariaDB container: ", termErr)
		}
	}

	suite.dbHost, err = mariadbC.Host(suite.ctx)
	if err != nil {
		terminate()
		log.Fatal("Failed to get MariaDB container host: ", err)
	}

	mappedPort, err := mariadbC.MappedPort(suite.ctx, "3306/tcp")
	if err != nil {
		terminate()
		log.Fatal("Failed to get MariaDB container port: ", err)
	}
	suite.dbPort = strconv.Itoa(mappedPort.Int())

	m, err := suite.freshDatabase()
	if err != nil {
		terminate()
		log.Fatal("Failed to migrate MariaDB database: ", err)
	}
	// The migrator is only needed for this one-off reset; close it so its
	// database connection is not held open for the rest of the suite.
	if srcErr, dbErr := m.Close(); srcErr != nil || dbErr != nil {
		log.Print("Failed to close migrator: ", srcErr, dbErr)
	}
}
// TearDownSuite runs once after all of the suite's tests and terminates the
// MariaDB container started by SetupSuite.
//
// It does nothing when no container was recorded, i.e. when SetupSuite did
// not get as far as storing one, instead of dereferencing a nil container.
func (suite *MysqlRepoTestSuite) TearDownSuite() {
	if suite.mariadbC == nil {
		return
	}
	suite.Require().NoError(suite.mariadbC.Terminate(suite.ctx), "terminate MariaDB container")
}
// Run executes fn as a subtest named method. fn receives an open
// transaction on a dedicated connection.
//
// When fn returns, the transaction is rolled back and then the connection is
// closed, so nothing fn writes persists between subtests. If the connection
// or the transaction cannot be set up, the subtest fails before fn is called.
//
// NOTE: this method hides the embedded suite.Suite's own Run method, which
// has a different signature.
func (suite *MysqlRepoTestSuite) Run(t *testing.T, method string, fn func(t *testing.T, tx *gorm.DB)) {
	t.Run(method, func(t *testing.T) {
		db := suite.openDatabaseConnection()
		sqlDB, err := db.DB()
		if err != nil {
			t.Fatalf("get underlying sql.DB: %v", err)
		}
		// Deferred first so it runs last: close the pool after the rollback.
		defer sqlDB.Close()

		tx := db.Begin()
		if tx.Error != nil {
			t.Fatalf("begin transaction: %v", tx.Error)
		}
		// Discard everything fn wrote, even if fn fails or panics.
		defer tx.Rollback()

		fn(t, tx)
	})
}
func TestMySqlRepoTestSuite(t *testing.T) {
t.Parallel()
suite.Run(t, &MysqlRepoTestSuite{
dbUser: "root",
dbPassword: "password",
dbName: "owlnest_booking",
})
}
// getDsn builds the connection string for the suite's database in the
// user:password@tcp(host:port)/dbname?params form.
//
// The fixed query parameters enable multi-statement execution and parse
// DATETIME columns into time.Time. NOTE(review): multiStatements is
// presumably needed by the golang-migrate migrations — confirm before
// removing it.
func (suite *MysqlRepoTestSuite) getDsn() string {
	const params = "multiStatements=true&parseTime=true&charset=utf8mb4"
	return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?%s",
		suite.dbUser, suite.dbPassword,
		suite.dbHost, suite.dbPort,
		suite.dbName, params)
}
// openDatabaseConnection opens a new gorm connection to the suite's database
// using getDsn, with gorm's SQL logging silenced.
//
// Each call opens a separate connection; the caller should close it through
// db.DB(). Panics if the connection cannot be opened.
func (suite *MysqlRepoTestSuite) openDatabaseConnection() *gorm.DB {
	cfg := &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}
	db, err := gorm.Open(mysql.Open(suite.getDsn()), cfg)
	if err != nil {
		panic(err)
	}
	return db
}
// freshDatabase resets the suite's schema. It rolls back every migration
// found under file://migrations (a path resolved relative to the working
// directory) and then applies them all again. ErrNoChange from either step is
// treated as success.
//
// On success the caller owns the returned *migrate.Migrate and should Close
// it. On failure everything opened here is closed before returning, and the
// error is wrapped with the step that failed.
func (suite *MysqlRepoTestSuite) freshDatabase() (*migrate.Migrate, error) {
	db, err := sql.Open("mysql", suite.getDsn())
	if err != nil {
		return nil, fmt.Errorf("open database: %w", err)
	}

	// Close errors on the failure paths below are ignored on purpose: the
	// original error is the one worth reporting. sql.DB.Close is idempotent,
	// so closing db after the driver/migrator may already have closed it is safe.
	driver, err := gomigratemysql.WithInstance(db, &gomigratemysql.Config{})
	if err != nil {
		db.Close()
		return nil, fmt.Errorf("create migrate mysql driver: %w", err)
	}

	m, err := migrate.NewWithDatabaseInstance("file://migrations", "mysql", driver)
	if err != nil {
		driver.Close()
		db.Close()
		return nil, fmt.Errorf("create migrator: %w", err)
	}

	// Roll back to an empty schema first so every run starts from scratch.
	if err := m.Down(); err != nil && err != migrate.ErrNoChange {
		m.Close()
		db.Close()
		return nil, fmt.Errorf("roll back migrations: %w", err)
	}

	if err := m.Up(); err != nil && err != migrate.ErrNoChange {
		m.Close()
		db.Close()
		return nil, fmt.Errorf("apply migrations: %w", err)
	}
	return m, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment